[Mlir-commits] [mlir] [mlir][Target][NVPTX] Add fatbin support to NVPTX compilation. (PR #65398)

Fabian Mora llvmlistbot at llvm.org
Tue Sep 5 11:28:50 PDT 2023


https://github.com/fabianmcg created https://github.com/llvm/llvm-project/pull/65398:

Currently, the NVPTX tool compilation path only calls `ptxas`; thus, the architecture of the GPU running the binary must match the target architecture, or else the runtime throws an architecture-mismatch error.

This patch adds a call to `fatbinary`, which creates a fat binary containing both the cubin object and the PTX code; this allows the driver to JIT-compile the PTX at runtime if there is an architecture mismatch.

This patch is needed to start migrating the Integration Tests; otherwise, they fail at runtime with an architecture-mismatch error.

>From 216b84f1aaeaffe79331fb2cc2c262ad72836949 Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Tue, 5 Sep 2023 17:11:36 +0000
Subject: [PATCH] [mlir][Target][NVPTX] Add fatbin support to NVPTX
 compilation.

Currently, the NVPTX tool compilation path only calls `ptxas`; thus, the
architecture of the GPU running the binary must match the target architecture,
or else the runtime throws an architecture-mismatch error.

This patch adds a call to `fatbinary`, which creates a fat binary containing
both the cubin object and the PTX code; this allows the driver to JIT-compile
the PTX at runtime if there is an architecture mismatch.
---
 mlir/lib/Target/LLVM/NVVM/Target.cpp | 181 +++++++++++++++++----------
 1 file changed, 116 insertions(+), 65 deletions(-)

diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 20b423009ef2c5..e040c22e25fb23 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -184,7 +184,7 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
   // 1. The toolkit path in `targetOptions`.
   // 2. In the system PATH.
   // 3. The path from `getCUDAToolkitPath()`.
-  std::optional<std::string> findPtxas() const;
+  std::optional<std::string> findTool(StringRef tool);
 
   // Target options.
   gpu::TargetOptions targetOptions;
@@ -213,21 +213,21 @@ gpu::GPUModuleOp NVPTXSerializer::getOperation() {
   return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation());
 }
 
-std::optional<std::string> NVPTXSerializer::findPtxas() const {
+std::optional<std::string> NVPTXSerializer::findTool(StringRef tool) {
   // Find the `ptxas` compiler.
   // 1. Check the toolkit path given in the command line.
   StringRef pathRef = targetOptions.getToolkitPath();
   SmallVector<char, 256> path;
   if (pathRef.size()) {
     path.insert(path.begin(), pathRef.begin(), pathRef.end());
-    llvm::sys::path::append(path, "bin", "ptxas");
+    llvm::sys::path::append(path, "bin", tool);
     if (llvm::sys::fs::can_execute(path))
       return StringRef(path.data(), path.size()).str();
   }
 
   // 2. Check PATH.
   if (std::optional<std::string> ptxasCompiler =
-          llvm::sys::Process::FindInEnvPath("PATH", "ptxas"))
+          llvm::sys::Process::FindInEnvPath("PATH", tool))
     return *ptxasCompiler;
 
   // 3. Check `getCUDAToolkitPath()`.
@@ -235,10 +235,15 @@ std::optional<std::string> NVPTXSerializer::findPtxas() const {
   path.clear();
   if (pathRef.size()) {
     path.insert(path.begin(), pathRef.begin(), pathRef.end());
-    llvm::sys::path::append(path, "bin", "ptxas");
+    llvm::sys::path::append(path, "bin", tool);
     if (llvm::sys::fs::can_execute(path))
       return StringRef(path.data(), path.size()).str();
   }
+  getOperation().emitError()
+      << "Couldn't find the `" << tool
+      << "` binary. Please specify the toolkit "
+         "path, add the compiler to $PATH, or set one of the environment "
+         "variables in `NVVM::getCUDAToolkitPath()`.";
   return std::nullopt;
 }
 
@@ -246,15 +251,13 @@ std::optional<std::string> NVPTXSerializer::findPtxas() const {
 // with this mechanism and let another stage take care of it.
 std::optional<SmallVector<char, 0>>
 NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
-  // Find the PTXAS compiler.
-  std::optional<std::string> ptxasCompiler = findPtxas();
-  if (!ptxasCompiler) {
-    getOperation().emitError()
-        << "Couldn't find the `ptxas` compiler. Please specify the toolkit "
-           "path, add the compiler to $PATH, or set one of the environment "
-           "variables in `NVVM::getCUDAToolkitPath()`.";
+  // Find the `ptxas` & `fatbinary`.
+  std::optional<std::string> ptxasCompiler = findTool("ptxas");
+  std::optional<std::string> fatbinaryTool = findTool("fatbinary");
+  if (!ptxasCompiler || !fatbinaryTool)
     return std::nullopt;
-  }
+
+  Location loc = getOperation().getLoc();
 
   // Base name for all temp files: mlir-<module name>-<target triple>-<chip>.
   std::string basename =
@@ -268,99 +271,147 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
   std::optional<TmpFile> logFile = createTemp(basename, "log");
   if (!logFile)
     return std::nullopt;
-  std::optional<TmpFile> cubinFile = createTemp(basename, "cubin");
-  if (!cubinFile)
+  std::optional<TmpFile> fatbinFile = createTemp(basename, "fatbin");
+  if (!fatbinFile)
     return std::nullopt;
+  Twine cubinFilename = ptxFile->first + ".cubin";
+  TmpFile cubinFile(cubinFilename.str(), llvm::FileRemover(cubinFilename));
 
   std::error_code ec;
   // Dump the PTX to a temp file.
   {
     llvm::raw_fd_ostream ptxStream(ptxFile->first, ec);
     if (ec) {
-      getOperation().emitError()
-          << "Couldn't open the file: `" << ptxFile->first
-          << "`, error message: " << ec.message();
+      emitError(loc) << "Couldn't open the file: `" << ptxFile->first
+                     << "`, error message: " << ec.message();
       return std::nullopt;
     }
     ptxStream << ptxCode;
     if (ptxStream.has_error()) {
-      getOperation().emitError()
-          << "An error occurred while writing the PTX to: `" << ptxFile->first
-          << "`.";
+      emitError(loc) << "An error occurred while writing the PTX to: `"
+                     << ptxFile->first << "`.";
       return std::nullopt;
     }
     ptxStream.flush();
   }
 
-  // Create PTX args.
+  // Command redirects.
+  std::optional<StringRef> redirects[] = {
+      std::nullopt,
+      logFile->first,
+      logFile->first,
+  };
+
+  // Get any extra args passed in `targetOptions`.
+  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
+      targetOptions.tokenizeCmdOptions();
+
+  // Create ptxas args.
   std::string optLevel = std::to_string(this->optLevel);
   SmallVector<StringRef, 12> ptxasArgs(
       {StringRef("ptxas"), StringRef("-arch"), getTarget().getChip(),
-       StringRef(ptxFile->first), StringRef("-o"), StringRef(cubinFile->first),
+       StringRef(ptxFile->first), StringRef("-o"), StringRef(cubinFile.first),
        "--opt-level", optLevel});
 
-  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
-      targetOptions.tokenizeCmdOptions();
-  for (auto arg : cmdOpts.second)
-    ptxasArgs.push_back(arg);
+  bool useFatbin32 = false;
+  for (auto cArg : cmdOpts.second) {
+    // All `cmdOpts` are for `ptxas` except `-32` which passes `-32` to
+    // `fatbinary`, indicating a 32-bit target. By default a 64-bit target is
+    // assumed.
+    if (StringRef arg(cArg); arg != "-32")
+      ptxasArgs.push_back(arg);
+    else
+      useFatbin32 = true;
+  }
 
-  std::optional<StringRef> redirects[] = {
-      std::nullopt,
-      logFile->first,
-      logFile->first,
-  };
+  // Create the `fatbinary` args.
+  StringRef chip = getTarget().getChip();
+  // Remove the arch prefix to obtain the compute capability.
+  chip.consume_front("sm_"), chip.consume_front("compute_");
+  // Embed the cubin object.
+  std::string cubinArg =
+      llvm::formatv("--image3=kind=elf,sm={0},file={1}", chip, cubinFile.first)
+          .str();
+  // Embed the PTX file so the driver can JIT if needed.
+  std::string ptxArg =
+      llvm::formatv("--image3=kind=ptx,sm={0},file={1}", chip, ptxFile->first)
+          .str();
+  SmallVector<StringRef, 6> fatbinArgs({StringRef("fatbinary"),
+                                        useFatbin32 ? "-32" : "-64", cubinArg,
+                                        ptxArg, "--create", fatbinFile->first});
+
+  // Dump tool invocation commands.
+#define DEBUG_TYPE "serialize-to-binary"
+  LLVM_DEBUG({
+    llvm::dbgs() << "Tool invocation for module: "
+                 << getOperation().getNameAttr() << "\n";
+    llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
+    llvm::dbgs() << "\n";
+    llvm::interleave(fatbinArgs, llvm::dbgs(), " ");
+    llvm::dbgs() << "\n";
+  });
+#undef DEBUG_TYPE
 
-  // Invoke PTXAS.
+  // Helper function for printing tool error logs.
   std::string message;
-  if (llvm::sys::ExecuteAndWait(ptxasCompiler.value(), ptxasArgs,
-                                /*Env=*/std::nullopt,
-                                /*Redirects=*/redirects,
-                                /*SecondsToWait=*/0,
-                                /*MemoryLimit=*/0,
-                                /*ErrMsg=*/&message)) {
+  auto emitLogError =
+      [&](StringRef toolName) -> std::optional<SmallVector<char, 0>> {
     if (message.empty()) {
-      llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasStderr =
+      llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr =
           llvm::MemoryBuffer::getFile(logFile->first);
-      if (ptxasStderr)
-        getOperation().emitError() << "PTXAS invocation failed. PTXAS log:\n"
-                                   << ptxasStderr->get()->getBuffer();
+      if (toolStderr)
+        emitError(loc) << toolName << " invocation failed. Log:\n"
+                       << toolStderr->get()->getBuffer();
       else
-        getOperation().emitError() << "PTXAS invocation failed.";
+        emitError(loc) << toolName << " invocation failed.";
       return std::nullopt;
     }
-    getOperation().emitError()
-        << "PTXAS invocation failed, error message: " << message;
+    emitError(loc) << toolName
+                   << " invocation failed, error message: " << message;
     return std::nullopt;
-  }
+  };
 
-// Dump the output of PTXAS, helpful if the verbose flag was passed.
+  // Invoke PTXAS.
+  if (llvm::sys::ExecuteAndWait(ptxasCompiler.value(), ptxasArgs,
+                                /*Env=*/std::nullopt,
+                                /*Redirects=*/redirects,
+                                /*SecondsToWait=*/0,
+                                /*MemoryLimit=*/0,
+                                /*ErrMsg=*/&message))
+    return emitLogError("`ptxas`");
+
+  // Invoke `fatbin`.
+  message.clear();
+  if (llvm::sys::ExecuteAndWait(*fatbinaryTool, fatbinArgs,
+                                /*Env=*/std::nullopt,
+                                /*Redirects=*/redirects,
+                                /*SecondsToWait=*/0,
+                                /*MemoryLimit=*/0,
+                                /*ErrMsg=*/&message))
+    return emitLogError("`fatbinary`");
+
+// Dump the output of the tools, helpful if the verbose flag was passed.
 #define DEBUG_TYPE "serialize-to-binary"
   LLVM_DEBUG({
-    llvm::dbgs() << "PTXAS invocation for module: "
-                 << getOperation().getNameAttr() << "\n";
-    llvm::dbgs() << "Command: ";
-    llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
-    llvm::dbgs() << "\n";
-    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasLog =
+    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
         llvm::MemoryBuffer::getFile(logFile->first);
-    if (ptxasLog && (*ptxasLog)->getBuffer().size()) {
-      llvm::dbgs() << "Output:\n" << (*ptxasLog)->getBuffer() << "\n";
+    if (logBuffer && (*logBuffer)->getBuffer().size()) {
+      llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n";
       llvm::dbgs().flush();
     }
   });
 #undef DEBUG_TYPE
 
-  // Read the cubin file.
-  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cubinBuffer =
-      llvm::MemoryBuffer::getFile(cubinFile->first);
-  if (!cubinBuffer) {
-    getOperation().emitError()
-        << "Couldn't open the file: `" << cubinFile->first
-        << "`, error message: " << cubinBuffer.getError().message();
+  // Read the fatbin.
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fatbinBuffer =
+      llvm::MemoryBuffer::getFile(fatbinFile->first);
+  if (!fatbinBuffer) {
+    emitError(loc) << "Couldn't open the file: `" << fatbinFile->first
+                   << "`, error message: " << fatbinBuffer.getError().message();
     return std::nullopt;
   }
-  StringRef cubinStr = (*cubinBuffer)->getBuffer();
-  return SmallVector<char, 0>(cubinStr.begin(), cubinStr.end());
+  StringRef fatbin = (*fatbinBuffer)->getBuffer();
+  return SmallVector<char, 0>(fatbin.begin(), fatbin.end());
 }
 
 #if MLIR_NVPTXCOMPILER_ENABLED == 1



More information about the Mlir-commits mailing list