[clang] 37d0568 - [Clang] Introduce 'clang-nvlink-wrapper' to work around 'nvlink' (#96561)

via cfe-commits cfe-commits at lists.llvm.org
Mon Jul 22 16:20:18 PDT 2024


Author: Joseph Huber
Date: 2024-07-22T18:20:14-05:00
New Revision: 37d0568a6593adfe791c1327d99731050540e97a

URL: https://github.com/llvm/llvm-project/commit/37d0568a6593adfe791c1327d99731050540e97a
DIFF: https://github.com/llvm/llvm-project/commit/37d0568a6593adfe791c1327d99731050540e97a.diff

LOG: [Clang] Introduce 'clang-nvlink-wrapper' to work around 'nvlink' (#96561)

Summary:
The `clang-nvlink-wrapper` is a utility that I removed a while back
during the transition to the new driver. This patch adds back in a new,
upgraded version that does LTO + archive linking. It's not an easy
choice to reintroduce something I happily deleted, but this is the only
way to move forward with improving GPU support in LLVM.

While NVIDIA provides a linker called 'nvlink', its main interface is
very difficult to work with. It does not support LTO or static linking,
requires all inputs to use the non-standard `.cubin` extension, and rejects
link jobs that other linkers accept (e.g. empty ones). I have spent a
great deal of time hacking around this in the GPU `libc` implementation,
where I deliberately avoid LTO and static linking and have about 100
lines of hacky CMake dedicated to storing these files in a format that
the clang-linker-wrapper accepts to avoid this limitation.
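To make the `.cubin` requirement concrete: the driver code removed by this patch notes that `nvlink` performs RDC-mode linking on `.o` inputs and device linking on `.cubin` inputs, so the driver copied or renamed every input to end in `.cubin`. The helper below is a minimal, hypothetical sketch of only that naming rule (`toCubinName` is not an LLVM API, and the copy into a temporary file is omitted):

```cpp
#include <filesystem>
#include <string>

// Hypothetical helper: the file name under which an input must be handed to
// `nvlink` so that it performs device linking rather than RDC-mode linking.
std::string toCubinName(const std::string &Input) {
  std::filesystem::path P(Input);
  // Inputs that already carry the `.cubin` extension pass through unchanged.
  if (P.extension().string() == ".cubin")
    return Input;
  // Otherwise replace any existing extension (e.g. `.o`) with `.cubin`;
  // a name without an extension simply gains one.
  P.replace_extension(".cubin");
  return P.string();
}
```

For example, `toCubinName("foo.o")` yields `foo.cubin`, while `dir/bar.cubin` is returned as-is.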

The main reason I want to reintroduce this tool is that I am
planning to create a more standard C/C++ toolchain for GPUs.
This will install files like the following.
```
<install>/lib/nvptx64-nvidia-cuda/libc.a
<install>/lib/nvptx64-nvidia-cuda/libc++.a
<install>/lib/nvptx64-nvidia-cuda/libomp.a
<install>/lib/clang/19/lib/nvptx64-nvidia-cuda/libclang_rt.builtins.a
```
Linking in these libraries will then simply require passing `-lc` like
is already done for non-GPU toolchains. However, this doesn't work with
the currently deficient `nvlink` linker, so I consider this a blocking
issue to massively improving the state of building GPU libraries.
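Supporting `-lc` mostly comes down to mapping a library name to an archive file and searching the `-L` directories for it. The sketch below assumes the same outline as `searchLibrary` in ClangNVLinkWrapper.cpp: `-lfoo` resolves to `libfoo.a`, `-l:name` resolves to the exact file name, and the first matching directory wins. The function names are hypothetical, only static archives are considered, and the sysroot-relative `=dir` search path form is not handled.

```cpp
#include <filesystem>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: map an `-l` argument to the archive name to look for.
// `-l:libfoo.a` names the file exactly; `-lfoo` means `libfoo.a`.
std::string libraryFileName(const std::string &Spec) {
  if (!Spec.empty() && Spec.front() == ':')
    return Spec.substr(1);
  return "lib" + Spec + ".a";
}

// Search the `-L` directories in order and return the first existing match.
std::optional<std::string>
findStaticLibrary(const std::string &Spec,
                  const std::vector<std::string> &SearchPaths) {
  const std::string FileName = libraryFileName(Spec);
  for (const std::string &Dir : SearchPaths) {
    std::filesystem::path Candidate = std::filesystem::path(Dir) / FileName;
    if (std::filesystem::exists(Candidate))
      return Candidate.string();
  }
  return std::nullopt;
}
```

With `<install>/lib/nvptx64-nvidia-cuda` among the search paths, `-lc` would then find `libc.a` there, the same way it works for host toolchains.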

In the future we may be able to convince NVIDIA to port their linker to
`ld.lld`, but for now this is the only workable solution that allows us
to hack around the weird behavior of their closed-source software.
This also copies some amount of logic from the clang-linker-wrapper,
but not enough, I feel, to make merging the two worthwhile. In the
future it may be possible to delete that handling from there entirely.

Added: 
    clang/docs/ClangNVLinkWrapper.rst
    clang/test/Driver/nvlink-wrapper.c
    clang/tools/clang-nvlink-wrapper/CMakeLists.txt
    clang/tools/clang-nvlink-wrapper/ClangNVLinkWrapper.cpp
    clang/tools/clang-nvlink-wrapper/NVLinkOpts.td

Modified: 
    clang/docs/index.rst
    clang/lib/Driver/ToolChains/Cuda.cpp
    clang/lib/Driver/ToolChains/Cuda.h
    clang/test/Driver/cuda-cross-compiling.c
    clang/test/lit.cfg.py
    clang/tools/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/clang/docs/ClangNVLinkWrapper.rst b/clang/docs/ClangNVLinkWrapper.rst
new file mode 100644
index 0000000000000..2acdb054572f8
--- /dev/null
+++ b/clang/docs/ClangNVLinkWrapper.rst
@@ -0,0 +1,74 @@
+====================
+Clang nvlink Wrapper
+====================
+
+.. contents::
+   :local:
+
+.. _clang-nvlink-wrapper:
+
+Introduction
+============
+
+This tool works as a wrapper around the NVIDIA ``nvlink`` linker. The purpose
+of this wrapper is to provide an interface similar to the ``ld.lld`` linker
+while still relying on NVIDIA's proprietary linker to produce the final output.
+
+``nvlink`` has a number of known quirks that make it difficult to use in a
+unified offloading setting. For example, it does not accept ``.o`` files, as
+inputs must be named ``.cubin``. Static archives do not work, so passing a
+``.a`` file results in a linker error. ``nvlink`` also does not support link
+time optimization and ignores many standard linker arguments. This tool works around these issues.
+
+Usage
+=====
+
+This tool can be used with the following options. Any arguments that are not
+specific to the linker wrapper are forwarded to ``nvlink``.
+
+.. code-block:: console
+
+  OVERVIEW: A utility that wraps around the NVIDIA 'nvlink' linker.
+  This enables static linking and LTO handling for NVPTX targets.
+
+  USAGE: clang-nvlink-wrapper [options] <options to be passed to nvlink>
+
+  OPTIONS:
+    --arch <value>       Specify the 'sm_' name of the target architecture.
+    --cuda-path=<dir>    Set the system CUDA path
+    --dry-run            Print generated commands without running.
+    --feature <value>    Specify the '+ptx' feature to use for LTO.
+    -g                   Specify that this was a debug compile.
+    -help-hidden         Display all available options
+    -help                Display available options (--help-hidden for more)
+    -L <dir>             Add <dir> to the library search path
+    -l <libname>         Search for library <libname>
+    -mllvm <arg>         Arguments passed to LLVM, including Clang invocations,
+                         for which the '-mllvm' prefix is preserved. Use '-mllvm
+                         --help' for a list of options.
+    -o <path>            Path to file to write output
+    --plugin-opt=jobs=<value>
+                         Number of parallel LTO backend jobs
+    --plugin-opt=lto-partitions=<value>
+                         Number of LTO codegen partitions
+    --plugin-opt=O<O0, O1, O2, or O3>
+                         Optimization level for LTO
+    --plugin-opt=thinlto<value>
+                         Enable the thin-lto backend
+    --plugin-opt=<value> Arguments passed to LLVM, including Clang invocations,
+                         for which the '-mllvm' prefix is preserved. Use '-mllvm
+                         --help' for a list of options.
+    --save-temps         Save intermediate results
+    --version            Display the version number and exit
+    -v                   Print verbose information
+
+Example
+=======
+
+This tool is intended to be invoked when targeting the NVPTX toolchain directly
+as a cross-compiling target. This can be used to create standalone GPU
+executables with normal linking semantics similar to standard compilation.
+
+.. code-block:: console
+
+  clang --target=nvptx64-nvidia-cuda -march=native -flto=full input.c

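The Usage section of the new document states that options meant only for the wrapper are consumed by it and everything else is forwarded to ``nvlink``. The real tool makes that decision through its TableGen option table (note the `WrapperOnlyOption` flag in ClangNVLinkWrapper.cpp). The sketch below instead assumes a hypothetical, hard-coded set of wrapper-only option names and compares `--name=value` spellings by the part before `=`; options that take their value as a separate argument are not modeled.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical result of splitting a wrapper command line.
struct SplitArgs {
  std::vector<std::string> WrapperOnly; // consumed by the wrapper itself
  std::vector<std::string> Forwarded;   // passed through to `nvlink` unchanged
};

// Partition arguments by option name, preserving their relative order.
SplitArgs splitArgs(const std::vector<std::string> &Args,
                    const std::set<std::string> &WrapperOnlyNames) {
  SplitArgs Result;
  for (const std::string &Arg : Args) {
    // substr(0, npos) yields the whole argument when it contains no '='.
    std::string Name = Arg.substr(0, Arg.find('='));
    if (WrapperOnlyNames.count(Name))
      Result.WrapperOnly.push_back(Arg);
    else
      Result.Forwarded.push_back(Arg);
  }
  return Result;
}
```

With `--dry-run` and `--cuda-path` treated as wrapper-only, the command line `--dry-run -arch sm_52 --cuda-path=/opt/cuda -foo -o a.out` forwards `-arch sm_52 -foo -o a.out`, which mirrors the forwarding of the unknown `-foo` flag checked in `nvlink-wrapper.c`.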
diff  --git a/clang/docs/index.rst b/clang/docs/index.rst
index a35a867b96bd7..9bae0bd83243b 100644
--- a/clang/docs/index.rst
+++ b/clang/docs/index.rst
@@ -92,6 +92,7 @@ Using Clang Tools
    ClangFormatStyleOptions
    ClangFormattedStatus
    ClangLinkerWrapper
+   ClangNVLinkWrapper
    ClangOffloadBundler
    ClangOffloadPackager
    ClangRepl

diff  --git a/clang/lib/Driver/ToolChains/Cuda.cpp b/clang/lib/Driver/ToolChains/Cuda.cpp
index 08a4633902654..59453c484ae4f 100644
--- a/clang/lib/Driver/ToolChains/Cuda.cpp
+++ b/clang/lib/Driver/ToolChains/Cuda.cpp
@@ -461,13 +461,6 @@ void NVPTX::Assembler::ConstructJob(Compilation &C, const JobAction &JA,
   CmdArgs.push_back("--output-file");
   std::string OutputFileName = TC.getInputFilename(Output);
 
-  // If we are invoking `nvlink` internally we need to output a `.cubin` file.
-  // FIXME: This should hopefully be removed if NVIDIA updates their tooling.
-  if (!C.getInputArgs().getLastArg(options::OPT_c)) {
-    SmallString<256> Filename(Output.getFilename());
-    llvm::sys::path::replace_extension(Filename, "cubin");
-    OutputFileName = Filename.str();
-  }
   if (Output.isFilename() && OutputFileName != Output.getFilename())
     C.addTempFile(Args.MakeArgString(OutputFileName));
 
@@ -612,12 +605,21 @@ void NVPTX::Linker::ConstructJob(Compilation &C, const JobAction &JA,
   CmdArgs.push_back("-arch");
   CmdArgs.push_back(Args.MakeArgString(GPUArch));
 
+  if (Args.hasArg(options::OPT_ptxas_path_EQ))
+    CmdArgs.push_back(Args.MakeArgString(
+        "--ptxas-path=" + Args.getLastArgValue(options::OPT_ptxas_path_EQ)));
+
   // Add paths specified in LIBRARY_PATH environment variable as -L options.
   addDirectoryList(Args, CmdArgs, "-L", "LIBRARY_PATH");
 
   // Add standard library search paths passed on the command line.
   Args.AddAllArgs(CmdArgs, options::OPT_L);
   getToolChain().AddFilePathLibArgs(Args, CmdArgs);
+  AddLinkerInputs(getToolChain(), Inputs, Args, CmdArgs, JA);
+
+  if (C.getDriver().isUsingLTO())
+    addLTOOptions(getToolChain(), Args, CmdArgs, Output, Inputs[0],
+                  C.getDriver().getLTOMode() == LTOK_Thin);
 
   // Add paths for the default clang library path.
   SmallString<256> DefaultLibPath =
@@ -625,51 +627,12 @@ void NVPTX::Linker::ConstructJob(Compilation &C, const JobAction &JA,
   llvm::sys::path::append(DefaultLibPath, CLANG_INSTALL_LIBDIR_BASENAME);
   CmdArgs.push_back(Args.MakeArgString(Twine("-L") + DefaultLibPath));
 
-  for (const auto &II : Inputs) {
-    if (II.getType() == types::TY_LLVM_IR || II.getType() == types::TY_LTO_IR ||
-        II.getType() == types::TY_LTO_BC || II.getType() == types::TY_LLVM_BC) {
-      C.getDriver().Diag(diag::err_drv_no_linker_llvm_support)
-          << getToolChain().getTripleString();
-      continue;
-    }
-
-    // The 'nvlink' application performs RDC-mode linking when given a '.o'
-    // file and device linking when given a '.cubin' file. We always want to
-    // perform device linking, so just rename any '.o' files.
-    // FIXME: This should hopefully be removed if NVIDIA updates their tooling.
-    if (II.isFilename()) {
-      auto InputFile = getToolChain().getInputFilename(II);
-      if (llvm::sys::path::extension(InputFile) != ".cubin") {
-        // If there are no actions above this one then this is direct input and
-        // we can copy it. Otherwise the input is internal so a `.cubin` file
-        // should exist.
-        if (II.getAction() && II.getAction()->getInputs().size() == 0) {
-          const char *CubinF =
-              Args.MakeArgString(getToolChain().getDriver().GetTemporaryPath(
-                  llvm::sys::path::stem(InputFile), "cubin"));
-          if (llvm::sys::fs::copy_file(InputFile, C.addTempFile(CubinF)))
-            continue;
-
-          CmdArgs.push_back(CubinF);
-        } else {
-          SmallString<256> Filename(InputFile);
-          llvm::sys::path::replace_extension(Filename, "cubin");
-          CmdArgs.push_back(Args.MakeArgString(Filename));
-        }
-      } else {
-        CmdArgs.push_back(Args.MakeArgString(InputFile));
-      }
-    } else if (!II.isNothing()) {
-      II.getInputArg().renderAsInput(Args, CmdArgs);
-    }
-  }
-
   C.addCommand(std::make_unique<Command>(
       JA, *this,
       ResponseFileSupport{ResponseFileSupport::RF_Full, llvm::sys::WEM_UTF8,
                           "--options-file"},
-      Args.MakeArgString(getToolChain().GetProgramPath("nvlink")), CmdArgs,
-      Inputs, Output));
+      Args.MakeArgString(getToolChain().GetProgramPath("clang-nvlink-wrapper")),
+      CmdArgs, Inputs, Output));
 }
 
 void NVPTX::getNVPTXTargetFeatures(const Driver &D, const llvm::Triple &Triple,
@@ -949,11 +912,7 @@ std::string CudaToolChain::getInputFilename(const InputInfo &Input) const {
   if (Input.getType() != types::TY_Object || getDriver().offloadDeviceOnly())
     return ToolChain::getInputFilename(Input);
 
-  // Replace extension for object files with cubin because nvlink relies on
-  // these particular file names.
-  SmallString<256> Filename(ToolChain::getInputFilename(Input));
-  llvm::sys::path::replace_extension(Filename, "cubin");
-  return std::string(Filename);
+  return ToolChain::getInputFilename(Input);
 }
 
 llvm::opt::DerivedArgList *

diff  --git a/clang/lib/Driver/ToolChains/Cuda.h b/clang/lib/Driver/ToolChains/Cuda.h
index 7464d88cb350b..7a6a6fb209012 100644
--- a/clang/lib/Driver/ToolChains/Cuda.h
+++ b/clang/lib/Driver/ToolChains/Cuda.h
@@ -155,6 +155,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXToolChain : public ToolChain {
   bool isPIEDefault(const llvm::opt::ArgList &Args) const override {
     return false;
   }
+  bool HasNativeLLVMSupport() const override { return true; }
   bool isPICDefaultForced() const override { return false; }
   bool SupportsProfiling() const override { return false; }
 
@@ -192,6 +193,8 @@ class LLVM_LIBRARY_VISIBILITY CudaToolChain : public NVPTXToolChain {
     return &HostTC.getTriple();
   }
 
+  bool HasNativeLLVMSupport() const override { return false; }
+
   std::string getInputFilename(const InputInfo &Input) const override;
 
   llvm::opt::DerivedArgList *

diff  --git a/clang/test/Driver/cuda-cross-compiling.c b/clang/test/Driver/cuda-cross-compiling.c
index 1dc4520f485db..42d56cbfcc321 100644
--- a/clang/test/Driver/cuda-cross-compiling.c
+++ b/clang/test/Driver/cuda-cross-compiling.c
@@ -32,8 +32,8 @@
 // RUN:   | FileCheck -check-prefix=ARGS %s
 
 //      ARGS: -cc1" "-triple" "nvptx64-nvidia-cuda" "-S" {{.*}} "-target-cpu" "sm_61" "-target-feature" "+ptx{{[0-9]+}}" {{.*}} "-o" "[[PTX:.+]].s"
-// ARGS-NEXT: ptxas{{.*}}"-m64" "-O0" "--gpu-name" "sm_61" "--output-file" "[[CUBIN:.+]].cubin" "[[PTX]].s" "-c"
-// ARGS-NEXT: nvlink{{.*}}"-o" "a.out" "-arch" "sm_61" {{.*}} "[[CUBIN]].cubin"
+// ARGS-NEXT: ptxas{{.*}}"-m64" "-O0" "--gpu-name" "sm_61" "--output-file" "[[CUBIN:.+]].o" "[[PTX]].s" "-c"
+// ARGS-NEXT: clang-nvlink-wrapper{{.*}}"-o" "a.out" "-arch" "sm_61"{{.*}}"[[CUBIN]].o"
 
 //
 // Test the generated arguments to the CUDA binary utils when targeting NVPTX. 
@@ -55,7 +55,7 @@
 // RUN: %clang -target nvptx64-nvidia-cuda -march=sm_61 -### %t.o 2>&1 \
 // RUN:   | FileCheck -check-prefix=LINK %s
 
-// LINK: nvlink{{.*}}"-o" "a.out" "-arch" "sm_61" {{.*}} "{{.*}}.cubin"
+// LINK: clang-nvlink-wrapper{{.*}}"-o" "a.out" "-arch" "sm_61"{{.*}}[[CUBIN:.+]].o
 
 //
 // Test to ensure that we enable handling global constructors in a freestanding
@@ -72,7 +72,7 @@
 // RUN: %clang -target nvptx64-nvidia-cuda -Wl,-v -Wl,a,b -march=sm_52 -### %s 2>&1 \
 // RUN:   | FileCheck -check-prefix=LINKER-ARGS %s
 
-// LINKER-ARGS: nvlink{{.*}}"-v"{{.*}}"a" "b"
+// LINKER-ARGS: clang-nvlink-wrapper{{.*}}"-v"{{.*}}"a" "b"
 
 // Tests for handling a missing architecture.
 //

diff  --git a/clang/test/Driver/nvlink-wrapper.c b/clang/test/Driver/nvlink-wrapper.c
new file mode 100644
index 0000000000000..fdda93f1f9cdc
--- /dev/null
+++ b/clang/test/Driver/nvlink-wrapper.c
@@ -0,0 +1,65 @@
+// REQUIRES: x86-registered-target
+// REQUIRES: nvptx-registered-target
+
+#if defined(X)
+extern int y;
+int foo() { return y; }
+
+int x = 0;
+#elif defined(Y)
+int y = 42;
+#elif defined(Z)
+int z = 42;
+#elif defined(W)
+int w = 42;
+#elif defined(U)
+extern int x;
+extern int __attribute__((weak)) w;
+
+int bar() {
+  return x + w;
+}
+#else
+extern int y;
+int __attribute__((visibility("hidden"))) x = 999;
+int baz() { return y + x; }
+#endif
+
+// Create various inputs to test basic linking and LTO capabilities. Creating a
+// CUDA binary requires access to the `ptxas` executable, so we just use x64.
+// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -DX -o %t-x.o
+// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -DY -o %t-y.o
+// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -DZ -o %t-z.o
+// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -DW -o %t-w.o
+// RUN: %clang -cc1 %s -triple x86_64-unknown-linux-gnu -emit-obj -DU -o %t-u.o
+// RUN: llvm-ar rcs %t-x.a %t-x.o
+// RUN: llvm-ar rcs %t-y.a %t-y.o
+// RUN: llvm-ar rcs %t-z.a %t-z.o
+// RUN: llvm-ar rcs %t-w.a %t-w.o
+
+//
+// Check that we forward any unrecognized argument to 'nvlink'.
+//
+// RUN: clang-nvlink-wrapper --dry-run -arch sm_52 %t-u.o -foo -o a.out 2>&1 \
+// RUN:   | FileCheck %s --check-prefix=ARGS
+// ARGS: nvlink{{.*}} -arch sm_52 -foo -o a.out [[INPUT:.+]].cubin
+
+//
+// Check the symbol resolution for static archives. We expect to link only the
+// `x` and `y` archives: `w` is referenced only through an extern weak symbol,
+// which does not trigger extraction, and `z` is not referenced at all.
+//
+// RUN: clang-nvlink-wrapper --dry-run %t-x.a %t-u.o %t-y.a %t-z.a %t-w.a \
+// RUN:   -arch sm_52 -o a.out 2>&1 | FileCheck %s --check-prefix=LINK
+// LINK: nvlink{{.*}} -arch sm_52 -o a.out [[INPUT:.+]].cubin {{.*}}-x-{{.*}}.cubin{{.*}}-y-{{.*}}.cubin
+
+// RUN: %clang -cc1 %s -triple nvptx64-nvidia-cuda -emit-llvm-bc -o %t.o
+
+//
+// Check that the LTO interface works and properly preserves symbols used in a
+// regular object file.
+//
+// RUN: clang-nvlink-wrapper --dry-run %t.o %t-u.o %t-y.a \
+// RUN:   -arch sm_52 -o a.out 2>&1 | FileCheck %s --check-prefix=LTO
+// LTO: ptxas{{.*}} -m64 -c [[PTX:.+]].s -O3 -arch sm_52 -o [[CUBIN:.+]].cubin
+// LTO: nvlink{{.*}} -arch sm_52 -o a.out [[CUBIN]].cubin {{.*}}-u-{{.*}}.cubin {{.*}}-y-{{.*}}.cubin

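The archive test above relies on lazy extraction: a static archive member is linked only if it defines a symbol that is still undefined and referenced by something already in the link, and an undefined weak reference does not count. The following is a simplified, hypothetical model of that rule (`Member` and `resolveArchives` are illustrative names, not the wrapper's `Symbol`/`StringMap` machinery, and bitcode inputs are ignored). It rescans the archives until nothing changes, which is one way to obtain the result the `LINK` check expects even though `x.a` appears before `u.o` on the command line.

```cpp
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical input model: what an object or archive member defines and
// which undefined symbols it references.
struct Member {
  std::string Name;
  std::set<std::string> Defines;        // symbols defined by this input
  std::set<std::string> References;     // undefined, non-weak references
  std::set<std::string> WeakReferences; // undefined weak references
};

// Returns the names of the archive members that get extracted, in order.
std::vector<std::string>
resolveArchives(const std::vector<Member> &Objects,
                const std::vector<Member> &ArchiveMembers) {
  std::set<std::string> Defined; // symbols defined by linked inputs
  std::set<std::string> Needed;  // non-weak undefined references seen so far
  auto Link = [&](const Member &M) {
    Defined.insert(M.Defines.begin(), M.Defines.end());
    // Weak references are deliberately not recorded: they never extract.
    Needed.insert(M.References.begin(), M.References.end());
  };

  // Regular object files are always part of the link.
  for (const Member &M : Objects)
    Link(M);

  std::vector<std::string> Extracted;
  std::vector<bool> Linked(ArchiveMembers.size(), false);
  bool Changed = true;
  // Rescan until no member resolves a pending reference (a fixed point).
  while (Changed) {
    Changed = false;
    for (std::size_t I = 0; I < ArchiveMembers.size(); ++I) {
      if (Linked[I])
        continue;
      const Member &M = ArchiveMembers[I];
      bool ResolvesPending = false;
      for (const std::string &Sym : M.Defines)
        if (Needed.count(Sym) && !Defined.count(Sym)) {
          ResolvesPending = true;
          break;
        }
      if (!ResolvesPending)
        continue;
      Linked[I] = true;
      Link(M);
      Extracted.push_back(M.Name);
      Changed = true;
    }
  }
  return Extracted;
}
```

Feeding it the same inputs as the `LINK` check (`u` references `x` and weakly references `w`; the `x` member references `y`) extracts `x` and then `y`, while `z` and `w` stay out of the link.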
diff  --git a/clang/test/lit.cfg.py b/clang/test/lit.cfg.py
index 2e0fbc2c9e1dd..2bd7501136a10 100644
--- a/clang/test/lit.cfg.py
+++ b/clang/test/lit.cfg.py
@@ -95,6 +95,7 @@
     "llvm-ifs",
     "yaml2obj",
     "clang-linker-wrapper",
+    "clang-nvlink-wrapper",
     "llvm-lto",
     "llvm-lto2",
     "llvm-profdata",

diff  --git a/clang/tools/CMakeLists.txt b/clang/tools/CMakeLists.txt
index bdd8004be3e02..4885afc1584d0 100644
--- a/clang/tools/CMakeLists.txt
+++ b/clang/tools/CMakeLists.txt
@@ -9,6 +9,7 @@ add_clang_subdirectory(clang-format-vs)
 add_clang_subdirectory(clang-fuzzer)
 add_clang_subdirectory(clang-import-test)
 add_clang_subdirectory(clang-linker-wrapper)
+add_clang_subdirectory(clang-nvlink-wrapper)
 add_clang_subdirectory(clang-offload-packager)
 add_clang_subdirectory(clang-offload-bundler)
 add_clang_subdirectory(clang-scan-deps)

diff  --git a/clang/tools/clang-nvlink-wrapper/CMakeLists.txt b/clang/tools/clang-nvlink-wrapper/CMakeLists.txt
new file mode 100644
index 0000000000000..d46f66994cf39
--- /dev/null
+++ b/clang/tools/clang-nvlink-wrapper/CMakeLists.txt
@@ -0,0 +1,44 @@
+set(LLVM_LINK_COMPONENTS
+  ${LLVM_TARGETS_TO_BUILD}
+  BitWriter
+  Core
+  BinaryFormat
+  MC
+  Target
+  TransformUtils
+  Analysis
+  Passes
+  IRReader
+  Object
+  Option
+  Support
+  TargetParser
+  CodeGen
+  LTO
+  )
+
+set(LLVM_TARGET_DEFINITIONS NVLinkOpts.td)
+tablegen(LLVM NVLinkOpts.inc -gen-opt-parser-defs)
+add_public_tablegen_target(NVLinkWrapperOpts)
+
+if(NOT CLANG_BUILT_STANDALONE)
+  set(tablegen_deps intrinsics_gen NVLinkWrapperOpts)
+endif()
+
+add_clang_tool(clang-nvlink-wrapper
+  ClangNVLinkWrapper.cpp
+
+  DEPENDS
+  ${tablegen_deps}
+  )
+
+set(CLANG_NVLINK_WRAPPER_LIB_DEPS
+  clangBasic
+  )
+
+target_compile_options(clang-nvlink-wrapper PRIVATE "-g" "-O0")
+
+target_link_libraries(clang-nvlink-wrapper
+  PRIVATE
+  ${CLANG_NVLINK_WRAPPER_LIB_DEPS}
+  )

diff  --git a/clang/tools/clang-nvlink-wrapper/ClangNVLinkWrapper.cpp b/clang/tools/clang-nvlink-wrapper/ClangNVLinkWrapper.cpp
new file mode 100644
index 0000000000000..5b6d7a0cffb49
--- /dev/null
+++ b/clang/tools/clang-nvlink-wrapper/ClangNVLinkWrapper.cpp
@@ -0,0 +1,781 @@
+//===-- clang-nvlink-wrapper/ClangNVLinkWrapper.cpp - NVIDIA linker util --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------===//
+//
+// This tool wraps around the NVIDIA linker called 'nvlink'. The NVIDIA linker
+// is required to create NVPTX applications, but does not support common
+// features like LTO or archives. This utility wraps around the tool to cover
+// its deficiencies. This tool can be removed once NVIDIA improves their linker
+// or ports it to `ld.lld`.
+//
+//===---------------------------------------------------------------------===//
+
+#include "clang/Basic/Version.h"
+
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/BinaryFormat/Magic.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/CodeGen/CommandFlags.h"
+#include "llvm/IR/DiagnosticPrinter.h"
+#include "llvm/LTO/LTO.h"
+#include "llvm/Object/Archive.h"
+#include "llvm/Object/ArchiveWriter.h"
+#include "llvm/Object/Binary.h"
+#include "llvm/Object/ELFObjectFile.h"
+#include "llvm/Object/IRObjectFile.h"
+#include "llvm/Object/ObjectFile.h"
+#include "llvm/Object/OffloadBinary.h"
+#include "llvm/Option/ArgList.h"
+#include "llvm/Option/OptTable.h"
+#include "llvm/Option/Option.h"
+#include "llvm/Remarks/HotnessThresholdParser.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileOutputBuffer.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/Support/StringSaver.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/WithColor.h"
+
+using namespace llvm;
+using namespace llvm::opt;
+using namespace llvm::object;
+
+// Various tools (e.g., llc and opt) duplicate this series of declarations for
+// options related to passes and remarks.
+static cl::opt<bool> RemarksWithHotness(
+    "pass-remarks-with-hotness",
+    cl::desc("With PGO, include profile count in optimization remarks"),
+    cl::Hidden);
+
+static cl::opt<std::optional<uint64_t>, false, remarks::HotnessThresholdParser>
+    RemarksHotnessThreshold(
+        "pass-remarks-hotness-threshold",
+        cl::desc("Minimum profile count required for "
+                 "an optimization remark to be output. "
+                 "Use 'auto' to apply the threshold from profile summary."),
+        cl::value_desc("N or 'auto'"), cl::init(0), cl::Hidden);
+
+static cl::opt<std::string>
+    RemarksFilename("pass-remarks-output",
+                    cl::desc("Output filename for pass remarks"),
+                    cl::value_desc("filename"));
+
+static cl::opt<std::string>
+    RemarksPasses("pass-remarks-filter",
+                  cl::desc("Only record optimization remarks from passes whose "
+                           "names match the given regular expression"),
+                  cl::value_desc("regex"));
+
+static cl::opt<std::string> RemarksFormat(
+    "pass-remarks-format",
+    cl::desc("The format used for serializing remarks (default: YAML)"),
+    cl::value_desc("format"), cl::init("yaml"));
+
+static cl::list<std::string>
+    PassPlugins("load-pass-plugin",
+                cl::desc("Load passes from plugin library"));
+
+static cl::opt<std::string> PassPipeline(
+    "passes",
+    cl::desc(
+        "A textual description of the pass pipeline. To have analysis passes "
+        "available before a certain pass, add 'require<foo-analysis>'. "
+        "'-passes' overrides the pass pipeline (but not all effects) from "
+        "specifying '--opt-level=O?' (O2 is the default) to "
+        "clang-nvlink-wrapper. Be sure to include the corresponding "
+        "'default<O?>' in '-passes'."));
+static cl::alias PassPipeline2("p", cl::aliasopt(PassPipeline),
+                               cl::desc("Alias for -passes"));
+
+static void printVersion(raw_ostream &OS) {
+  OS << clang::getClangToolFullVersion("clang-nvlink-wrapper") << '\n';
+}
+
+/// The value of `argv[0]` when run.
+static const char *Executable;
+
+/// Temporary files to be cleaned up.
+static SmallVector<SmallString<128>> TempFiles;
+
+/// Codegen flags for LTO backend.
+static codegen::RegisterCodeGenFlags CodeGenFlags;
+
+namespace {
+// Must not overlap with llvm::opt::DriverFlag.
+enum WrapperFlags { WrapperOnlyOption = (1 << 4) };
+
+enum ID {
+  OPT_INVALID = 0, // This is not an option ID.
+#define OPTION(...) LLVM_MAKE_OPT_ID(__VA_ARGS__),
+#include "NVLinkOpts.inc"
+  LastOption
+#undef OPTION
+};
+
+#define PREFIX(NAME, VALUE)                                                    \
+  static constexpr StringLiteral NAME##_init[] = VALUE;                        \
+  static constexpr ArrayRef<StringLiteral> NAME(NAME##_init,                   \
+                                                std::size(NAME##_init) - 1);
+#include "NVLinkOpts.inc"
+#undef PREFIX
+
+static constexpr OptTable::Info InfoTable[] = {
+#define OPTION(...) LLVM_CONSTRUCT_OPT_INFO(__VA_ARGS__),
+#include "NVLinkOpts.inc"
+#undef OPTION
+};
+
+class WrapperOptTable : public opt::GenericOptTable {
+public:
+  WrapperOptTable() : opt::GenericOptTable(InfoTable) {}
+};
+
+const OptTable &getOptTable() {
+  static const WrapperOptTable *Table = []() {
+    auto Result = std::make_unique<WrapperOptTable>();
+    return Result.release();
+  }();
+  return *Table;
+}
+
+[[noreturn]] void reportError(Error E) {
+  outs().flush();
+  logAllUnhandledErrors(std::move(E), WithColor::error(errs(), Executable));
+  exit(EXIT_FAILURE);
+}
+
+void diagnosticHandler(const DiagnosticInfo &DI) {
+  std::string ErrStorage;
+  raw_string_ostream OS(ErrStorage);
+  DiagnosticPrinterRawOStream DP(OS);
+  DI.print(DP);
+
+  switch (DI.getSeverity()) {
+  case DS_Error:
+    WithColor::error(errs(), Executable) << ErrStorage << "\n";
+    break;
+  case DS_Warning:
+    WithColor::warning(errs(), Executable) << ErrStorage << "\n";
+    break;
+  case DS_Note:
+    WithColor::note(errs(), Executable) << ErrStorage << "\n";
+    break;
+  case DS_Remark:
+    WithColor::remark(errs()) << ErrStorage << "\n";
+    break;
+  }
+}
+
+Expected<StringRef> createTempFile(const ArgList &Args, const Twine &Prefix,
+                                   StringRef Extension) {
+  SmallString<128> OutputFile;
+  if (Args.hasArg(OPT_save_temps)) {
+    (Prefix + "." + Extension).toNullTerminatedStringRef(OutputFile);
+  } else {
+    if (std::error_code EC =
+            sys::fs::createTemporaryFile(Prefix, Extension, OutputFile))
+      return createFileError(OutputFile, EC);
+  }
+
+  TempFiles.emplace_back(std::move(OutputFile));
+  return TempFiles.back();
+}
+
+Expected<std::string> findProgram(StringRef Name, ArrayRef<StringRef> Paths) {
+  ErrorOr<std::string> Path = sys::findProgramByName(Name, Paths);
+  if (!Path)
+    Path = sys::findProgramByName(Name);
+  if (!Path)
+    return createStringError(Path.getError(),
+                             "Unable to find '" + Name + "' in path");
+  return *Path;
+}
+
+std::optional<std::string> findFile(StringRef Dir, StringRef Root,
+                                    const Twine &Name) {
+  SmallString<128> Path;
+  if (Dir.starts_with("="))
+    sys::path::append(Path, Root, Dir.substr(1), Name);
+  else
+    sys::path::append(Path, Dir, Name);
+
+  if (sys::fs::exists(Path))
+    return static_cast<std::string>(Path);
+  return std::nullopt;
+}
+
+std::optional<std::string>
+findFromSearchPaths(StringRef Name, StringRef Root,
+                    ArrayRef<StringRef> SearchPaths) {
+  for (StringRef Dir : SearchPaths)
+    if (std::optional<std::string> File = findFile(Dir, Root, Name))
+      return File;
+  return std::nullopt;
+}
+
+std::optional<std::string>
+searchLibraryBaseName(StringRef Name, StringRef Root,
+                      ArrayRef<StringRef> SearchPaths) {
+  for (StringRef Dir : SearchPaths)
+    if (std::optional<std::string> File =
+            findFile(Dir, Root, "lib" + Name + ".a"))
+      return File;
+  return std::nullopt;
+}
+
+/// Search for static libraries in the linker's library path given input like
+/// `-lfoo` or `-l:libfoo.a`.
+std::optional<std::string> searchLibrary(StringRef Input, StringRef Root,
+                                         ArrayRef<StringRef> SearchPaths) {
+  if (Input.starts_with(":"))
+    return findFromSearchPaths(Input.drop_front(), Root, SearchPaths);
+  return searchLibraryBaseName(Input, Root, SearchPaths);
+}
+
+void printCommands(ArrayRef<StringRef> CmdArgs) {
+  if (CmdArgs.empty())
+    return;
+
+  llvm::errs() << " \"" << CmdArgs.front() << "\" ";
+  llvm::errs() << llvm::join(std::next(CmdArgs.begin()), CmdArgs.end(), " ")
+               << "\n";
+}
+
+/// A minimum symbol interface that provides the necessary information to
+/// extract archive members and resolve LTO symbols.
+struct Symbol {
+  enum Flags {
+    None = 0,
+    Undefined = 1 << 0,
+    Weak = 1 << 1,
+  };
+
+  Symbol() : File(), Flags(None), UsedInRegularObj(false) {}
+
+  Symbol(MemoryBufferRef File, const irsymtab::Reader::SymbolRef Sym)
+      : File(File), Flags(0), UsedInRegularObj(false) {
+    if (Sym.isUndefined())
+      Flags |= Undefined;
+    if (Sym.isWeak())
+      Flags |= Weak;
+  }
+
+  Symbol(MemoryBufferRef File, const SymbolRef Sym)
+      : File(File), Flags(0), UsedInRegularObj(false) {
+    auto FlagsOrErr = Sym.getFlags();
+    if (!FlagsOrErr)
+      reportError(FlagsOrErr.takeError());
+    if (*FlagsOrErr & SymbolRef::SF_Undefined)
+      Flags |= Undefined;
+    if (*FlagsOrErr & SymbolRef::SF_Weak)
+      Flags |= Weak;
+
+    auto NameOrErr = Sym.getName();
+    if (!NameOrErr)
+      reportError(NameOrErr.takeError());
+  }
+
+  bool isWeak() const { return Flags & Weak; }
+  bool isUndefined() const { return Flags & Undefined; }
+
+  MemoryBufferRef File;
+  uint32_t Flags;
+  bool UsedInRegularObj;
+};
+
+Expected<StringRef> runPTXAs(StringRef File, const ArgList &Args) {
+  std::string CudaPath = Args.getLastArgValue(OPT_cuda_path_EQ).str();
+  Expected<std::string> PTXAsPath = Args.getLastArgValue(OPT_ptxas_path);
+  if (PTXAsPath->empty())
+    PTXAsPath = findProgram("ptxas", {CudaPath + "/bin"});
+  if (!PTXAsPath)
+    return PTXAsPath.takeError();
+
+  auto TempFileOrErr = createTempFile(
+      Args, sys::path::stem(Args.getLastArgValue(OPT_o, "a.out")), "cubin");
+  if (!TempFileOrErr)
+    return TempFileOrErr.takeError();
+
+  SmallVector<StringRef> AssemblerArgs({*PTXAsPath, "-m64", "-c", File});
+  if (Args.hasArg(OPT_verbose))
+    AssemblerArgs.push_back("-v");
+  if (Args.hasArg(OPT_g)) {
+    if (Args.hasArg(OPT_O))
+      WithColor::warning(errs(), Executable)
+          << "Optimized debugging not supported, overriding to '-O0'\n";
+    AssemblerArgs.push_back("-O0");
+  } else
+    AssemblerArgs.push_back(
+        Args.MakeArgString("-O" + Args.getLastArgValue(OPT_O, "3")));
+  AssemblerArgs.append({"-arch", Args.getLastArgValue(OPT_arch)});
+  AssemblerArgs.append({"-o", *TempFileOrErr});
+
+  if (Args.hasArg(OPT_dry_run) || Args.hasArg(OPT_verbose))
+    printCommands(AssemblerArgs);
+  if (Args.hasArg(OPT_dry_run))
+    return Args.MakeArgString(*TempFileOrErr);
+  if (sys::ExecuteAndWait(*PTXAsPath, AssemblerArgs))
+    return createStringError("'" + sys::path::filename(*PTXAsPath) + "'" +
+                             " failed");
+  return Args.MakeArgString(*TempFileOrErr);
+}
+
+Expected<std::unique_ptr<lto::LTO>> createLTO(const ArgList &Args) {
+  const llvm::Triple Triple("nvptx64-nvidia-cuda");
+  lto::Config Conf;
+  lto::ThinBackend Backend;
+  unsigned Jobs = 0;
+  if (auto *Arg = Args.getLastArg(OPT_jobs))
+    if (!llvm::to_integer(Arg->getValue(), Jobs) || Jobs == 0)
+      reportError(createStringError("%s: expected a positive integer, got '%s'",
+                                    Arg->getSpelling().data(),
+                                    Arg->getValue()));
+  Backend = lto::createInProcessThinBackend(
+      llvm::heavyweight_hardware_concurrency(Jobs));
+
+  Conf.CPU = Args.getLastArgValue(OPT_arch);
+  Conf.Options = codegen::InitTargetOptionsFromCodeGenFlags(Triple);
+
+  Conf.RemarksFilename = RemarksFilename;
+  Conf.RemarksPasses = RemarksPasses;
+  Conf.RemarksWithHotness = RemarksWithHotness;
+  Conf.RemarksHotnessThreshold = RemarksHotnessThreshold;
+  Conf.RemarksFormat = RemarksFormat;
+
+  Conf.MAttrs = {Args.getLastArgValue(OPT_feature, "").str()};
+  std::optional<CodeGenOptLevel> CGOptLevelOrNone =
+      CodeGenOpt::parseLevel(Args.getLastArgValue(OPT_O, "2")[0]);
+  assert(CGOptLevelOrNone && "Invalid optimization level");
+  Conf.CGOptLevel = *CGOptLevelOrNone;
+  Conf.OptLevel = Args.getLastArgValue(OPT_O, "2")[0] - '0';
+  Conf.DefaultTriple = Triple.getTriple();
+
+  Conf.OptPipeline = PassPipeline;
+  Conf.PassPlugins = PassPlugins;
+
+  Conf.DiagHandler = diagnosticHandler;
+  Conf.CGFileType = CodeGenFileType::AssemblyFile;
+
+  if (Args.hasArg(OPT_lto_emit_llvm)) {
+    Conf.PreCodeGenModuleHook = [&](size_t, const Module &M) {
+      std::error_code EC;
+      raw_fd_ostream LinkedBitcode(Args.getLastArgValue(OPT_o, "a.out"), EC);
+      if (EC)
+        reportError(errorCodeToError(EC));
+      WriteBitcodeToFile(M, LinkedBitcode);
+      return false;
+    };
+  }
+
+  if (Args.hasArg(OPT_save_temps))
+    if (Error Err = Conf.addSaveTemps(
+            (Args.getLastArgValue(OPT_o, "a.out") + ".").str()))
+      return Err;
+
+  unsigned Partitions = 1;
+  if (auto *Arg = Args.getLastArg(OPT_lto_partitions))
+    if (!llvm::to_integer(Arg->getValue(), Partitions) || Partitions == 0)
+      reportError(createStringError("%s: expected a positive integer, got '%s'",
+                                    Arg->getSpelling().data(),
+                                    Arg->getValue()));
+  lto::LTO::LTOKind Kind = Args.hasArg(OPT_thinlto) ? lto::LTO::LTOK_UnifiedThin
+                                                    : lto::LTO::LTOK_Default;
+  return std::make_unique<lto::LTO>(std::move(Conf), Backend, Partitions, Kind);
+}
+
+Expected<bool> getSymbolsFromBitcode(MemoryBufferRef Buffer,
+                                     StringMap<Symbol> &SymTab, bool IsLazy) {
+  Expected<IRSymtabFile> IRSymtabOrErr = readIRSymtab(Buffer);
+  if (!IRSymtabOrErr)
+    return IRSymtabOrErr.takeError();
+  bool Extracted = !IsLazy;
+  StringMap<Symbol> PendingSymbols;
+  for (unsigned I = 0; I != IRSymtabOrErr->Mods.size(); ++I) {
+    for (const auto &IRSym : IRSymtabOrErr->TheReader.module_symbols(I)) {
+      if (IRSym.isFormatSpecific() || !IRSym.isGlobal())
+        continue;
+
+      Symbol &OldSym = !SymTab.count(IRSym.getName()) && IsLazy
+                           ? PendingSymbols[IRSym.getName()]
+                           : SymTab[IRSym.getName()];
+      Symbol Sym = Symbol(Buffer, IRSym);
+      if (OldSym.File.getBuffer().empty())
+        OldSym = Sym;
+
+      bool ResolvesReference =
+          !Sym.isUndefined() &&
+          (OldSym.isUndefined() || (OldSym.isWeak() && !Sym.isWeak())) &&
+          !(OldSym.isWeak() && OldSym.isUndefined() && IsLazy);
+      Extracted |= ResolvesReference;
+
+      Sym.UsedInRegularObj = OldSym.UsedInRegularObj;
+      if (ResolvesReference)
+        OldSym = Sym;
+    }
+  }
+  if (Extracted)
+    for (auto &[Name, Symbol] : PendingSymbols)
+      SymTab[Name] = Symbol;
+  return Extracted;
+}
+
+Expected<bool> getSymbolsFromObject(ObjectFile &ObjFile,
+                                    StringMap<Symbol> &SymTab, bool IsLazy) {
+  bool Extracted = !IsLazy;
+  StringMap<Symbol> PendingSymbols;
+  for (SymbolRef ObjSym : ObjFile.symbols()) {
+    auto NameOrErr = ObjSym.getName();
+    if (!NameOrErr)
+      return NameOrErr.takeError();
+
+    Symbol &OldSym = !SymTab.count(*NameOrErr) && IsLazy
+                         ? PendingSymbols[*NameOrErr]
+                         : SymTab[*NameOrErr];
+    Symbol Sym = Symbol(ObjFile.getMemoryBufferRef(), ObjSym);
+    if (OldSym.File.getBuffer().empty())
+      OldSym = Sym;
+
+    bool ResolvesReference = OldSym.isUndefined() && !Sym.isUndefined() &&
+                             (!OldSym.isWeak() || !IsLazy);
+    Extracted |= ResolvesReference;
+
+    if (ResolvesReference)
+      OldSym = Sym;
+    OldSym.UsedInRegularObj = true;
+  }
+  if (Extracted)
+    for (auto &[Name, Symbol] : PendingSymbols)
+      SymTab[Name] = Symbol;
+  return Extracted;
+}
+
+Expected<bool> getSymbols(MemoryBufferRef Buffer, StringMap<Symbol> &SymTab,
+                          bool IsLazy) {
+  switch (identify_magic(Buffer.getBuffer())) {
+  case file_magic::bitcode: {
+    return getSymbolsFromBitcode(Buffer, SymTab, IsLazy);
+  }
+  case file_magic::elf_relocatable: {
+    Expected<std::unique_ptr<ObjectFile>> ObjFile =
+        ObjectFile::createObjectFile(Buffer);
+    if (!ObjFile)
+      return ObjFile.takeError();
+    return getSymbolsFromObject(**ObjFile, SymTab, IsLazy);
+  }
+  default:
+    return createStringError("Unsupported file type");
+  }
+}
+
+Expected<SmallVector<StringRef>> getInput(const ArgList &Args) {
+  SmallVector<StringRef> LibraryPaths;
+  for (const opt::Arg *Arg : Args.filtered(OPT_library_path))
+    LibraryPaths.push_back(Arg->getValue());
+
+  bool WholeArchive = false;
+  SmallVector<std::pair<std::unique_ptr<MemoryBuffer>, bool>> InputFiles;
+  for (const opt::Arg *Arg : Args.filtered(
+           OPT_INPUT, OPT_library, OPT_whole_archive, OPT_no_whole_archive)) {
+    if (Arg->getOption().matches(OPT_whole_archive) ||
+        Arg->getOption().matches(OPT_no_whole_archive)) {
+      WholeArchive = Arg->getOption().matches(OPT_whole_archive);
+      continue;
+    }
+
+    std::optional<std::string> Filename =
+        Arg->getOption().matches(OPT_library)
+            ? searchLibrary(Arg->getValue(), /*Root=*/"", LibraryPaths)
+            : std::string(Arg->getValue());
+
+    if (!Filename && Arg->getOption().matches(OPT_library))
+      return createStringError("unable to find library -l%s", Arg->getValue());
+
+    if (!Filename || !sys::fs::exists(*Filename) ||
+        sys::fs::is_directory(*Filename))
+      continue;
+
+    ErrorOr<std::unique_ptr<MemoryBuffer>> BufferOrErr =
+        MemoryBuffer::getFileOrSTDIN(*Filename);
+    if (std::error_code EC = BufferOrErr.getError())
+      return createFileError(*Filename, EC);
+
+    MemoryBufferRef Buffer = **BufferOrErr;
+    switch (identify_magic(Buffer.getBuffer())) {
+    case file_magic::bitcode:
+    case file_magic::elf_relocatable:
+      InputFiles.emplace_back(std::move(*BufferOrErr), /*IsLazy=*/false);
+      break;
+    case file_magic::archive: {
+      Expected<std::unique_ptr<llvm::object::Archive>> LibFile =
+          object::Archive::create(Buffer);
+      if (!LibFile)
+        return LibFile.takeError();
+      Error Err = Error::success();
+      for (auto Child : (*LibFile)->children(Err)) {
+        auto ChildBufferOrErr = Child.getMemoryBufferRef();
+        if (!ChildBufferOrErr)
+          return ChildBufferOrErr.takeError();
+        std::unique_ptr<MemoryBuffer> ChildBuffer =
+            MemoryBuffer::getMemBufferCopy(
+                ChildBufferOrErr->getBuffer(),
+                ChildBufferOrErr->getBufferIdentifier());
+        InputFiles.emplace_back(std::move(ChildBuffer), !WholeArchive);
+      }
+      if (Err)
+        return Err;
+      break;
+    }
+    default:
+      return createStringError("Unsupported file type");
+    }
+  }
+
+  bool Extracted = true;
+  StringMap<Symbol> SymTab;
+  SmallVector<std::unique_ptr<MemoryBuffer>> LinkerInput;
+  while (Extracted) {
+    Extracted = false;
+    for (auto &[Input, IsLazy] : InputFiles) {
+      if (!Input)
+        continue;
+
+      // Archive members only extract if they define needed symbols. We will
+      // re-scan all the inputs if any files were extracted for the link job.
+      Expected<bool> ExtractOrErr = getSymbols(*Input, SymTab, IsLazy);
+      if (!ExtractOrErr)
+        return ExtractOrErr.takeError();
+
+      Extracted |= *ExtractOrErr;
+      if (!*ExtractOrErr)
+        continue;
+
+      LinkerInput.emplace_back(std::move(Input));
+    }
+  }
+  InputFiles.clear();
+
+  // Extract any bitcode files to be passed to the LTO pipeline.
+  SmallVector<std::unique_ptr<MemoryBuffer>> BitcodeFiles;
+  for (auto &Input : LinkerInput)
+    if (identify_magic(Input->getBuffer()) == file_magic::bitcode)
+      BitcodeFiles.emplace_back(std::move(Input));
+  llvm::erase_if(LinkerInput, [](const auto &F) { return !F; });
+
+  // Run the LTO pipeline on the extracted inputs.
+  SmallVector<StringRef> Files;
+  if (!BitcodeFiles.empty()) {
+    auto LTOBackendOrErr = createLTO(Args);
+    if (!LTOBackendOrErr)
+      return LTOBackendOrErr.takeError();
+    lto::LTO &LTOBackend = **LTOBackendOrErr;
+    for (auto &BitcodeFile : BitcodeFiles) {
+      Expected<std::unique_ptr<lto::InputFile>> BitcodeFileOrErr =
+          llvm::lto::InputFile::create(*BitcodeFile);
+      if (!BitcodeFileOrErr)
+        return BitcodeFileOrErr.takeError();
+
+      const auto Symbols = (*BitcodeFileOrErr)->symbols();
+      SmallVector<lto::SymbolResolution, 16> Resolutions(Symbols.size());
+      size_t Idx = 0;
+      for (auto &Sym : Symbols) {
+        lto::SymbolResolution &Res = Resolutions[Idx++];
+        Symbol ObjSym = SymTab[Sym.getName()];
+        // We will use this as the prevailing symbol in LTO if it is not
+        // undefined and it is from the file that contained the canonical
+        // definition.
+        Res.Prevailing = !Sym.isUndefined() && ObjSym.File == *BitcodeFile;
+
+        // We need LTO to preserve the following global symbols:
+        // 1) Symbols used in regular objects.
+        // 2) Prevailing symbols that must be visible to the GPU runtime.
+        Res.VisibleToRegularObj =
+            ObjSym.UsedInRegularObj ||
+            (Res.Prevailing &&
+             (Sym.getVisibility() != GlobalValue::HiddenVisibility &&
+              !Sym.canBeOmittedFromSymbolTable()));
+
+        // Identify symbols that must be exported dynamically and can be
+        // referenced by other files (i.e. the runtime).
+        Res.ExportDynamic =
+            Sym.getVisibility() != GlobalValue::HiddenVisibility &&
+            !Sym.canBeOmittedFromSymbolTable();
+
+        // The NVIDIA platform does not support any symbol preemption.
+        Res.FinalDefinitionInLinkageUnit = true;
+
+        // We do not support linker redefined symbols (e.g. --wrap) for device
+        // image linking, so the symbols will not be changed after LTO.
+        Res.LinkerRedefined = false;
+      }
+
+      // Add the bitcode file with its resolved symbols to the LTO job.
+      if (Error Err = LTOBackend.add(std::move(*BitcodeFileOrErr), Resolutions))
+        return Err;
+    }
+
+    // Run the LTO job to compile the bitcode.
+    size_t MaxTasks = LTOBackend.getMaxTasks();
+    SmallVector<StringRef> LTOFiles(MaxTasks);
+    auto AddStream =
+        [&](size_t Task,
+            const Twine &ModuleName) -> std::unique_ptr<CachedFileStream> {
+      int FD = -1;
+      auto &TempFile = LTOFiles[Task];
+      if (Args.hasArg(OPT_lto_emit_asm))
+        TempFile = Args.getLastArgValue(OPT_o, "a.out");
+      else {
+        auto TempFileOrErr = createTempFile(
+            Args, sys::path::stem(Args.getLastArgValue(OPT_o, "a.out")), "s");
+        if (!TempFileOrErr)
+          reportError(TempFileOrErr.takeError());
+        TempFile = Args.MakeArgString(*TempFileOrErr);
+      }
+      if (std::error_code EC = sys::fs::openFileForWrite(TempFile, FD))
+        reportError(errorCodeToError(EC));
+      return std::make_unique<CachedFileStream>(
+          std::make_unique<llvm::raw_fd_ostream>(FD, true));
+    };
+
+    if (Error Err = LTOBackend.run(AddStream))
+      return Err;
+
+    if (Args.hasArg(OPT_lto_emit_llvm) || Args.hasArg(OPT_lto_emit_asm))
+      return Files;
+
+    for (StringRef LTOFile : LTOFiles) {
+      auto FileOrErr = runPTXAs(LTOFile, Args);
+      if (!FileOrErr)
+        return FileOrErr.takeError();
+      Files.emplace_back(*FileOrErr);
+    }
+  }
+
+  // Copy all of the input files to a new file ending in `.cubin`. The 'nvlink'
+  // linker requires all NVPTX inputs to have this extension for some reason.
+  for (auto &Input : LinkerInput) {
+    auto TempFileOrErr = createTempFile(
+        Args, sys::path::stem(Input->getBufferIdentifier()), "cubin");
+    if (!TempFileOrErr)
+      return TempFileOrErr.takeError();
+    Expected<std::unique_ptr<FileOutputBuffer>> OutputOrErr =
+        FileOutputBuffer::create(*TempFileOrErr, Input->getBuffer().size());
+    if (!OutputOrErr)
+      return OutputOrErr.takeError();
+    std::unique_ptr<FileOutputBuffer> Output = std::move(*OutputOrErr);
+    llvm::copy(Input->getBuffer(), Output->getBufferStart());
+    if (Error E = Output->commit())
+      return E;
+    Files.emplace_back(Args.MakeArgString(*TempFileOrErr));
+  }
+
+  return Files;
+}
+
+Error runNVLink(ArrayRef<StringRef> Files, const ArgList &Args) {
+  if (Args.hasArg(OPT_lto_emit_asm) || Args.hasArg(OPT_lto_emit_llvm))
+    return Error::success();
+
+  std::string CudaPath = Args.getLastArgValue(OPT_cuda_path_EQ).str();
+  Expected<std::string> NVLinkPath = findProgram("nvlink", {CudaPath + "/bin"});
+  if (!NVLinkPath)
+    return NVLinkPath.takeError();
+
+  ArgStringList NewLinkerArgs;
+  for (const opt::Arg *Arg : Args) {
+    // Do not forward arguments only intended for the linker wrapper.
+    if (Arg->getOption().hasFlag(WrapperOnlyOption))
+      continue;
+
+    // Do not forward any inputs that we have processed.
+    if (Arg->getOption().matches(OPT_INPUT) ||
+        Arg->getOption().matches(OPT_library))
+      continue;
+
+    Arg->render(Args, NewLinkerArgs);
+  }
+
+  llvm::transform(Files, std::back_inserter(NewLinkerArgs),
+                  [&](StringRef Arg) { return Args.MakeArgString(Arg); });
+
+  SmallVector<StringRef> LinkerArgs({*NVLinkPath});
+  if (!Args.hasArg(OPT_o))
+    LinkerArgs.append({"-o", "a.out"});
+  for (StringRef Arg : NewLinkerArgs)
+    LinkerArgs.push_back(Arg);
+
+  if (Args.hasArg(OPT_dry_run) || Args.hasArg(OPT_verbose))
+    printCommands(LinkerArgs);
+  if (Args.hasArg(OPT_dry_run))
+    return Error::success();
+  if (sys::ExecuteAndWait(*NVLinkPath, LinkerArgs))
+    return createStringError("'" + sys::path::filename(*NVLinkPath) + "'" +
+                             " failed");
+  return Error::success();
+}
+
+} // namespace
+
+int main(int argc, char **argv) {
+  InitLLVM X(argc, argv);
+  InitializeAllTargetInfos();
+  InitializeAllTargets();
+  InitializeAllTargetMCs();
+  InitializeAllAsmParsers();
+  InitializeAllAsmPrinters();
+
+  Executable = argv[0];
+  sys::PrintStackTraceOnErrorSignal(argv[0]);
+
+  const OptTable &Tbl = getOptTable();
+  BumpPtrAllocator Alloc;
+  StringSaver Saver(Alloc);
+  auto Args = Tbl.parseArgs(argc, argv, OPT_INVALID, Saver, [&](StringRef Err) {
+    reportError(createStringError(inconvertibleErrorCode(), Err));
+  });
+
+  if (Args.hasArg(OPT_help) || Args.hasArg(OPT_help_hidden)) {
+    Tbl.printHelp(
+        outs(), "clang-nvlink-wrapper [options] <options to be passed to nvlink>",
+        "A utility that wraps around the NVIDIA 'nvlink' linker.\n"
+        "This enables static linking and LTO handling for NVPTX targets.",
+        Args.hasArg(OPT_help_hidden), Args.hasArg(OPT_help_hidden));
+    return EXIT_SUCCESS;
+  }
+
+  if (Args.hasArg(OPT_version)) {
+    printVersion(outs());
+    return EXIT_SUCCESS;
+  }
+
+  // Forward '-mllvm' and any remaining '--plugin-opt=' arguments to LLVM.
+  SmallVector<const char *> NewArgv = {argv[0]};
+  for (const opt::Arg *Arg : Args.filtered(OPT_mllvm))
+    NewArgv.push_back(Arg->getValue());
+  for (const opt::Arg *Arg : Args.filtered(OPT_plugin_opt))
+    NewArgv.push_back(Arg->getValue());
+  cl::ParseCommandLineOptions(NewArgv.size(), &NewArgv[0]);
+
+  // Get the input files to pass to 'nvlink'.
+  auto FilesOrErr = getInput(Args);
+  if (!FilesOrErr)
+    reportError(FilesOrErr.takeError());
+
+  // Run 'nvlink' on the generated inputs.
+  if (Error Err = runNVLink(*FilesOrErr, Args))
+    reportError(std::move(Err));
+
+  // Remove the temporary files created.
+  if (!Args.hasArg(OPT_save_temps))
+    for (const auto &TempFile : TempFiles)
+      if (std::error_code EC = sys::fs::remove(TempFile))
+        reportError(createFileError(TempFile, EC));
+
+  return EXIT_SUCCESS;
+}

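The archive handling in `getInput()` above is a fixed-point iteration. Eager inputs (plain objects, bitcode, and archive members under `--whole-archive`) are always kept, while lazy archive members are kept only once they define a symbol that is still undefined, and the input list is rescanned until a full pass extracts nothing. The following is a minimal sketch of that selection rule under a simplified model. `Member` and `selectInputs` are hypothetical names, and weak symbols, bitcode, and the `UsedInRegularObj` bookkeeping are deliberately left out.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified model of one linker input: an object file or a
// single archive member. Weak symbols and bitcode are deliberately ignored.
struct Member {
  std::string Name;
  std::vector<std::string> Defines;    // symbols this input defines
  std::vector<std::string> References; // symbols this input uses
  bool IsLazy; // true for archive members outside --whole-archive
};

// Returns the names of the inputs selected for the link, in selection order.
// The whole list is rescanned until a full pass extracts nothing, in the
// spirit of the `while (Extracted)` loop in getInput().
std::vector<std::string> selectInputs(const std::vector<Member> &Inputs) {
  std::set<std::string> Defined, Undefined;
  std::vector<bool> Taken(Inputs.size(), false);
  std::vector<std::string> Selected;

  bool Extracted = true;
  while (Extracted) {
    Extracted = false;
    for (std::size_t I = 0; I < Inputs.size(); ++I) {
      if (Taken[I])
        continue; // Already part of the link.
      const Member &M = Inputs[I];

      // Eager inputs are always kept; lazy members only if they define a
      // symbol that is currently undefined.
      bool Resolves = !M.IsLazy;
      for (const std::string &Sym : M.Defines)
        if (Undefined.count(Sym))
          Resolves = true;
      if (!Resolves)
        continue;

      Taken[I] = true;
      Extracted = true;
      Selected.push_back(M.Name);
      for (const std::string &Sym : M.Defines) {
        Defined.insert(Sym);
        Undefined.erase(Sym);
      }
      for (const std::string &Sym : M.References)
        if (!Defined.count(Sym))
          Undefined.insert(Sym);
    }
  }
  return Selected;
}
```

With the inputs `main.o` (eager, references `foo`) followed by the lazy members `bar.o` (defines `bar`), `foo.o` (defines `foo`, references `bar`) and `unused.o` (defines `baz`), the sketch selects `main.o` and `foo.o` on the first pass, `bar.o` on the second, and never extracts `unused.o`. Because the whole list is rescanned, the order of members within or across archives does not matter, unlike traditional single-pass Unix linkers.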
diff  --git a/clang/tools/clang-nvlink-wrapper/NVLinkOpts.td b/clang/tools/clang-nvlink-wrapper/NVLinkOpts.td
new file mode 100644
index 0000000000000..e84b530f2787d
--- /dev/null
+++ b/clang/tools/clang-nvlink-wrapper/NVLinkOpts.td
@@ -0,0 +1,90 @@
+include "llvm/Option/OptParser.td"
+
+def WrapperOnlyOption : OptionFlag;
+
+def help : Flag<["-", "--"], "help">,
+  HelpText<"Display available options (--help-hidden for more)">;
+
+def help_hidden : Flag<["-", "--"], "help-hidden">,
+  HelpText<"Display all available options">;
+
+def verbose : Flag<["-"], "v">, HelpText<"Print verbose information">;
+def version : Flag<["--"], "version">,
+  HelpText<"Display the version number and exit">;
+
+def cuda_path_EQ : Joined<["--"], "cuda-path=">,
+  MetaVarName<"<dir>">, HelpText<"Set the system CUDA path">;
+def ptxas_path_EQ : Joined<["--"], "ptxas-path=">,
+  MetaVarName<"<dir>">, HelpText<"Set the 'ptxas' path">;
+
+def o : JoinedOrSeparate<["-"], "o">, MetaVarName<"<path>">,
+  HelpText<"Path to file to write output">;
+def output : Separate<["--"], "output-file">, Alias<o>, Flags<[HelpHidden]>,
+  HelpText<"Alias for -o">;
+
+def library_path : JoinedOrSeparate<["-"], "L">, MetaVarName<"<dir>">,
+  HelpText<"Add <dir> to the library search path">;
+def library_path_S : Separate<["--", "-"], "library-path">, Flags<[HelpHidden]>,
+  Alias<library_path>;
+def library_path_EQ : Joined<["--", "-"], "library-path=">, Flags<[HelpHidden]>,
+  Alias<library_path>;
+
+def library : JoinedOrSeparate<["-"], "l">, MetaVarName<"<libname>">,
+  HelpText<"Search for library <libname>">;
+def library_S : Separate<["--", "-"], "library">, Flags<[HelpHidden]>,
+  Alias<library>;
+def library_EQ : Joined<["--", "-"], "library=">, Flags<[HelpHidden]>,
+  Alias<library>;
+
+def arch : Separate<["--", "-"], "arch">,
+  HelpText<"Specify the 'sm_' name of the target architecture.">;
+def : Joined<["--", "-"], "plugin-opt=mcpu=">,
+  Flags<[HelpHidden, WrapperOnlyOption]>, Alias<arch>;
+
+def feature : Separate<["--", "-"], "feature">, Flags<[WrapperOnlyOption]>,
+  HelpText<"Specify the '+ptx' feature to use for LTO.">;
+
+def g : Flag<["-"], "g">, HelpText<"Specify that this was a debug compile.">;
+def debug : Flag<["--"], "debug">, Alias<g>;
+
+def lto_emit_llvm : Flag<["--"], "lto-emit-llvm">, Flags<[WrapperOnlyOption]>,
+  HelpText<"Emit LLVM-IR bitcode">;
+def lto_emit_asm : Flag<["--"], "lto-emit-asm">, Flags<[WrapperOnlyOption]>,
+  HelpText<"Emit assembly code">;
+
+def O : Joined<["--", "-"], "plugin-opt=O">,
+  Flags<[WrapperOnlyOption]>, MetaVarName<"<O0, O1, O2, or O3>">,
+  HelpText<"Optimization level for LTO">;
+
+def thinlto : Joined<["--", "-"], "plugin-opt=thinlto">,
+  Flags<[WrapperOnlyOption]>, HelpText<"Enable the thin-lto backend">;
+def lto_partitions : Joined<["--", "-"], "plugin-opt=lto-partitions=">,
+  Flags<[WrapperOnlyOption]>, HelpText<"Number of LTO codegen partitions">;
+def jobs : Joined<["--", "-"], "plugin-opt=jobs=">,
+  Flags<[WrapperOnlyOption]>, HelpText<"Number of threads to use for ThinLTO">;
+def : Joined<["--", "-"], "plugin-opt=emit-llvm">,
+  Flags<[WrapperOnlyOption]>, Alias<lto_emit_llvm>;
+def : Joined<["--", "-"], "plugin-opt=emit-asm">,
+  Flags<[WrapperOnlyOption]>, Alias<lto_emit_asm>;
+def plugin_opt : Joined<["--", "-"], "plugin-opt=">, Flags<[WrapperOnlyOption]>,
+  HelpText<"Options passed to LLVM, not including the Clang invocation. Use "
+           "'--plugin-opt=--help' for a list of options.">;
+
+def save_temps : Flag<["--", "-"], "save-temps">,
+  Flags<[WrapperOnlyOption]>, HelpText<"Save intermediate results">;
+
+def whole_archive : Flag<["--", "-"], "whole-archive">,
+  Flags<[WrapperOnlyOption, HelpHidden]>;
+def no_whole_archive : Flag<["--", "-"], "no-whole-archive">,
+  Flags<[WrapperOnlyOption, HelpHidden]>;
+
+def mllvm : Separate<["-"], "mllvm">, Flags<[WrapperOnlyOption]>,
+  MetaVarName<"<arg>">,
+  HelpText<"Arguments passed to LLVM, including Clang invocations, for which "
+           "the '-mllvm' prefix is preserved. Use '-mllvm --help' for a list "
+           "of options.">;
+def mllvm_EQ : Joined<["-"], "mllvm=">, Flags<[HelpHidden]>,
+  Alias<mllvm>;
+
+def dry_run : Flag<["--", "-"], "dry-run">, Flags<[WrapperOnlyOption]>,
+  HelpText<"Print generated commands without running.">;

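`runNVLink()` uses the `WrapperOnlyOption` flag defined above to decide what reaches `nvlink`. Flagged options and already-processed inputs are dropped, everything else is re-rendered, `-o a.out` is supplied when no output was given, and the generated `.cubin` files are appended last. A minimal sketch of that filtering follows, with a hypothetical `ParsedArg` struct standing in for `llvm::opt::Arg` and assuming every value-taking option is rendered in separate form (the real `Arg::render` keeps joined spellings joined).

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a parsed llvm::opt::Arg.
struct ParsedArg {
  std::string Spelling; // e.g. "-arch" or "--plugin-opt=O"
  std::string Value;    // empty for plain flags
  bool WrapperOnly;     // option carries the WrapperOnlyOption flag
  bool IsInput;         // OPT_INPUT or OPT_library, consumed by the wrapper
};

// Builds the nvlink command line: wrapper-only options and consumed inputs
// are dropped, everything else is forwarded, and the prepared files go last.
std::vector<std::string> buildNVLinkArgs(const std::vector<ParsedArg> &Args,
                                         const std::vector<std::string> &Files) {
  std::vector<std::string> Cmd = {"nvlink"};

  // Default the output name when the user did not pass '-o'.
  bool HasOutput = false;
  for (const ParsedArg &A : Args)
    if (A.Spelling == "-o")
      HasOutput = true;
  if (!HasOutput) {
    Cmd.push_back("-o");
    Cmd.push_back("a.out");
  }

  for (const ParsedArg &A : Args) {
    if (A.WrapperOnly || A.IsInput)
      continue;
    Cmd.push_back(A.Spelling);
    if (!A.Value.empty())
      Cmd.push_back(A.Value); // Rendered in separate form for simplicity.
  }

  Cmd.insert(Cmd.end(), Files.begin(), Files.end());
  return Cmd;
}
```

Keying the filter on a TableGen flag rather than on a hard-coded list in the C++ source means a new wrapper-only option only needs `Flags<[WrapperOnlyOption]>` in `NVLinkOpts.td` to stay off the `nvlink` command line.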