[Mlir-commits] [mlir] c16adb0 - [mlir][Target][NVPTX] Add fatbin support to NVPTX compilation. (#65398)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 7 04:44:45 PDT 2023


Author: Fabian Mora
Date: 2023-09-07T07:44:41-04:00
New Revision: c16adb0dcb1fb64c16d406e02f73242d0cd247e5

URL: https://github.com/llvm/llvm-project/commit/c16adb0dcb1fb64c16d406e02f73242d0cd247e5
DIFF: https://github.com/llvm/llvm-project/commit/c16adb0dcb1fb64c16d406e02f73242d0cd247e5.diff

LOG: [mlir][Target][NVPTX] Add fatbin support to NVPTX compilation. (#65398)

Currently, the NVPTX tool compilation path only calls `ptxas`; thus, the
GPU running the binary must match the target architecture, otherwise the
runtime reports an error due to the architecture mismatch.

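This patch adds a call to `fatbinary`, which bundles the cubin object and
the PTX code into a fat binary. If the GPU architecture does not match the
target, the driver can then JIT-compile the embedded PTX at runtime. The
default compilation format becomes `binOrFatbin`, which lets the target
decide whether to produce a plain binary or a fat binary.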

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
    mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
    mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
    mlir/lib/Target/LLVM/NVVM/Target.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index adf950d7b119d68..e0bf560dbd98b92 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -43,10 +43,15 @@ class TargetOptions {
 public:
   /// The target representation of the compilation process.
   typedef enum {
-    offload,  /// The process should produce an offloading representation. For
-              /// the NVVM & ROCDL targets this option produces LLVM IR.
-    assembly, /// The process should produce assembly code.
-    binary    /// The process should produce a binary.
+    offload = 1,  /// The process should produce an offloading representation.
+                  /// For the NVVM & ROCDL targets this option produces LLVM IR.
+    assembly = 2, /// The process should produce assembly code.
+    binary = 4,   /// The process should produce a binary.
+    fatbinary = 8, /// The process should produce a fat binary.
+    binOrFatbin =
+        binary |
+        fatbinary, /// The process should produce a binary or fatbinary. It's up
+                   /// to the target to decide which.
   } CompilationTarget;
 
   /// Constructor initializing the toolkit path, the list of files to link to,
@@ -54,7 +59,7 @@ class TargetOptions {
   /// compilation target is `binary`.
   TargetOptions(StringRef toolkitPath = {},
                 ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
-                CompilationTarget compilationTarget = binary);
+                CompilationTarget compilationTarget = binOrFatbin);
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -80,7 +85,7 @@ class TargetOptions {
   /// appropiate value: ie. `TargetOptions(TypeID::get<DerivedClass>())`.
   TargetOptions(TypeID typeID, StringRef toolkitPath = {},
                 ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
-                CompilationTarget compilationTarget = binary);
+                CompilationTarget compilationTarget = binOrFatbin);
 
   /// Path to the target toolkit.
   std::string toolkitPath;

diff  --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index fc20bd2ed921aea..ba8a6266604e46c 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -77,7 +77,7 @@ def GpuModuleToBinaryPass
            "Extra files to link to.">,
     Option<"cmdOptions", "opts", "std::string", [{""}],
            "Command line options to pass to the tools.">,
-    Option<"compilationTarget", "format", "std::string", [{"bin"}],
+    Option<"compilationTarget", "format", "std::string", [{"binOrFatbin"}],
            "The target representation of the compilation process.">
   ];
 }

diff  --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
index ad21dcc307fa24b..06b7dee6941e1f4 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
@@ -61,6 +61,8 @@ void GpuModuleToBinaryPass::runOnOperation() {
                          .Cases("offloading", "llvm", TargetOptions::offload)
                          .Cases("assembly", "isa", TargetOptions::assembly)
                          .Cases("binary", "bin", TargetOptions::binary)
+                         .Cases("fatbinary", "fatbin", TargetOptions::fatbinary)
+                         .Case("binOrFatbin", TargetOptions::binOrFatbin)
                          .Default(-1);
   if (targetFormat == -1)
     getOperation()->emitError() << "Invalid format specified.";

diff  --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 20b423009ef2c56..13188b1107d928b 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -180,11 +180,12 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
   // Create a temp file.
   std::optional<TmpFile> createTemp(StringRef name, StringRef suffix);
 
-  // Find the PTXAS compiler. The search order is:
+  // Find the `tool` path, where `tool` is the name of the binary to search,
+  // i.e. `ptxas` or `fatbinary`. The search order is:
   // 1. The toolkit path in `targetOptions`.
   // 2. In the system PATH.
   // 3. The path from `getCUDAToolkitPath()`.
-  std::optional<std::string> findPtxas() const;
+  std::optional<std::string> findTool(StringRef tool);
 
   // Target options.
   gpu::TargetOptions targetOptions;
@@ -213,21 +214,21 @@ gpu::GPUModuleOp NVPTXSerializer::getOperation() {
   return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation());
 }
 
-std::optional<std::string> NVPTXSerializer::findPtxas() const {
-  // Find the `ptxas` compiler.
+std::optional<std::string> NVPTXSerializer::findTool(StringRef tool) {
+  // Find the `tool` path.
   // 1. Check the toolkit path given in the command line.
   StringRef pathRef = targetOptions.getToolkitPath();
   SmallVector<char, 256> path;
   if (pathRef.size()) {
     path.insert(path.begin(), pathRef.begin(), pathRef.end());
-    llvm::sys::path::append(path, "bin", "ptxas");
+    llvm::sys::path::append(path, "bin", tool);
     if (llvm::sys::fs::can_execute(path))
       return StringRef(path.data(), path.size()).str();
   }
 
   // 2. Check PATH.
   if (std::optional<std::string> ptxasCompiler =
-          llvm::sys::Process::FindInEnvPath("PATH", "ptxas"))
+          llvm::sys::Process::FindInEnvPath("PATH", tool))
     return *ptxasCompiler;
 
   // 3. Check `getCUDAToolkitPath()`.
@@ -235,10 +236,15 @@ std::optional<std::string> NVPTXSerializer::findPtxas() const {
   path.clear();
   if (pathRef.size()) {
     path.insert(path.begin(), pathRef.begin(), pathRef.end());
-    llvm::sys::path::append(path, "bin", "ptxas");
+    llvm::sys::path::append(path, "bin", tool);
     if (llvm::sys::fs::can_execute(path))
       return StringRef(path.data(), path.size()).str();
   }
+  getOperation().emitError()
+      << "Couldn't find the `" << tool
+      << "` binary. Please specify the toolkit "
+         "path, add the compiler to $PATH, or set one of the environment "
+         "variables in `NVVM::getCUDAToolkitPath()`.";
   return std::nullopt;
 }
 
@@ -246,15 +252,20 @@ std::optional<std::string> NVPTXSerializer::findPtxas() const {
 // with this mechanism and let another stage take care of it.
 std::optional<SmallVector<char, 0>>
 NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
-  // Find the PTXAS compiler.
-  std::optional<std::string> ptxasCompiler = findPtxas();
-  if (!ptxasCompiler) {
-    getOperation().emitError()
-        << "Couldn't find the `ptxas` compiler. Please specify the toolkit "
-           "path, add the compiler to $PATH, or set one of the environment "
-           "variables in `NVVM::getCUDAToolkitPath()`.";
+  // Determine if the serializer should create a fatbinary with the PTX
+  // embedded or a simple CUBIN binary.
+  const bool createFatbin =
+      (targetOptions.getCompilationTarget() & gpu::TargetOptions::fatbinary) ==
+      gpu::TargetOptions::fatbinary;
+
+  // Find the `ptxas` & `fatbinary` tools.
+  std::optional<std::string> ptxasCompiler = findTool("ptxas");
+  if (!ptxasCompiler)
     return std::nullopt;
-  }
+  std::optional<std::string> fatbinaryTool = findTool("fatbinary");
+  if (createFatbin && !fatbinaryTool)
+    return std::nullopt;
+  Location loc = getOperation().getLoc();
 
   // Base name for all temp files: mlir-<module name>-<target triple>-<chip>.
   std::string basename =
@@ -268,99 +279,154 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
   std::optional<TmpFile> logFile = createTemp(basename, "log");
   if (!logFile)
     return std::nullopt;
-  std::optional<TmpFile> cubinFile = createTemp(basename, "cubin");
-  if (!cubinFile)
+  std::optional<TmpFile> binaryFile = createTemp(basename, "bin");
+  if (!binaryFile)
     return std::nullopt;
+  TmpFile cubinFile;
+  if (createFatbin) {
+    Twine cubinFilename = ptxFile->first + ".cubin";
+    cubinFile = TmpFile(cubinFilename.str(), llvm::FileRemover(cubinFilename));
+  } else {
+    cubinFile.first = binaryFile->first;
+  }
 
   std::error_code ec;
   // Dump the PTX to a temp file.
   {
     llvm::raw_fd_ostream ptxStream(ptxFile->first, ec);
     if (ec) {
-      getOperation().emitError()
-          << "Couldn't open the file: `" << ptxFile->first
-          << "`, error message: " << ec.message();
+      emitError(loc) << "Couldn't open the file: `" << ptxFile->first
+                     << "`, error message: " << ec.message();
       return std::nullopt;
     }
     ptxStream << ptxCode;
     if (ptxStream.has_error()) {
-      getOperation().emitError()
-          << "An error occurred while writing the PTX to: `" << ptxFile->first
-          << "`.";
+      emitError(loc) << "An error occurred while writing the PTX to: `"
+                     << ptxFile->first << "`.";
       return std::nullopt;
     }
     ptxStream.flush();
   }
 
-  // Create PTX args.
+  // Command redirects.
+  std::optional<StringRef> redirects[] = {
+      std::nullopt,
+      logFile->first,
+      logFile->first,
+  };
+
+  // Get any extra args passed in `targetOptions`.
+  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
+      targetOptions.tokenizeCmdOptions();
+
+  // Create ptxas args.
   std::string optLevel = std::to_string(this->optLevel);
   SmallVector<StringRef, 12> ptxasArgs(
       {StringRef("ptxas"), StringRef("-arch"), getTarget().getChip(),
-       StringRef(ptxFile->first), StringRef("-o"), StringRef(cubinFile->first),
+       StringRef(ptxFile->first), StringRef("-o"), StringRef(cubinFile.first),
        "--opt-level", optLevel});
 
-  std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
-      targetOptions.tokenizeCmdOptions();
-  for (auto arg : cmdOpts.second)
-    ptxasArgs.push_back(arg);
+  bool useFatbin32 = false;
+  for (auto cArg : cmdOpts.second) {
+    // All `cmdOpts` are for `ptxas` except `-32` which passes `-32` to
+    // `fatbinary`, indicating a 32-bit target. By default a 64-bit target is
+    // assumed.
+    if (StringRef arg(cArg); arg != "-32")
+      ptxasArgs.push_back(arg);
+    else
+      useFatbin32 = true;
+  }
 
-  std::optional<StringRef> redirects[] = {
-      std::nullopt,
-      logFile->first,
-      logFile->first,
-  };
+  // Create the `fatbinary` args.
+  StringRef chip = getTarget().getChip();
+  // Remove the arch prefix to obtain the compute capability.
+  chip.consume_front("sm_"), chip.consume_front("compute_");
+  // Embed the cubin object.
+  std::string cubinArg =
+      llvm::formatv("--image3=kind=elf,sm={0},file={1}", chip, cubinFile.first)
+          .str();
+  // Embed the PTX file so the driver can JIT if needed.
+  std::string ptxArg =
+      llvm::formatv("--image3=kind=ptx,sm={0},file={1}", chip, ptxFile->first)
+          .str();
+  SmallVector<StringRef, 6> fatbinArgs({StringRef("fatbinary"),
+                                        useFatbin32 ? "-32" : "-64", cubinArg,
+                                        ptxArg, "--create", binaryFile->first});
+
+  // Dump tool invocation commands.
+#define DEBUG_TYPE "serialize-to-binary"
+  LLVM_DEBUG({
+    llvm::dbgs() << "Tool invocation for module: "
+                 << getOperation().getNameAttr() << "\n";
+    llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
+    llvm::dbgs() << "\n";
+    if (createFatbin) {
+      llvm::interleave(fatbinArgs, llvm::dbgs(), " ");
+      llvm::dbgs() << "\n";
+    }
+  });
+#undef DEBUG_TYPE
 
-  // Invoke PTXAS.
+  // Helper function for printing tool error logs.
   std::string message;
-  if (llvm::sys::ExecuteAndWait(ptxasCompiler.value(), ptxasArgs,
-                                /*Env=*/std::nullopt,
-                                /*Redirects=*/redirects,
-                                /*SecondsToWait=*/0,
-                                /*MemoryLimit=*/0,
-                                /*ErrMsg=*/&message)) {
+  auto emitLogError =
+      [&](StringRef toolName) -> std::optional<SmallVector<char, 0>> {
     if (message.empty()) {
-      llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasStderr =
+      llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr =
           llvm::MemoryBuffer::getFile(logFile->first);
-      if (ptxasStderr)
-        getOperation().emitError() << "PTXAS invocation failed. PTXAS log:\n"
-                                   << ptxasStderr->get()->getBuffer();
+      if (toolStderr)
+        emitError(loc) << toolName << " invocation failed. Log:\n"
+                       << toolStderr->get()->getBuffer();
       else
-        getOperation().emitError() << "PTXAS invocation failed.";
+        emitError(loc) << toolName << " invocation failed.";
       return std::nullopt;
     }
-    getOperation().emitError()
-        << "PTXAS invocation failed, error message: " << message;
+    emitError(loc) << toolName
+                   << " invocation failed, error message: " << message;
     return std::nullopt;
-  }
+  };
 
-// Dump the output of PTXAS, helpful if the verbose flag was passed.
+  // Invoke PTXAS.
+  if (llvm::sys::ExecuteAndWait(ptxasCompiler.value(), ptxasArgs,
+                                /*Env=*/std::nullopt,
+                                /*Redirects=*/redirects,
+                                /*SecondsToWait=*/0,
+                                /*MemoryLimit=*/0,
+                                /*ErrMsg=*/&message))
+    return emitLogError("`ptxas`");
+
+  // Invoke `fatbinary`.
+  message.clear();
+  if (createFatbin && llvm::sys::ExecuteAndWait(*fatbinaryTool, fatbinArgs,
+                                                /*Env=*/std::nullopt,
+                                                /*Redirects=*/redirects,
+                                                /*SecondsToWait=*/0,
+                                                /*MemoryLimit=*/0,
+                                                /*ErrMsg=*/&message))
+    return emitLogError("`fatbinary`");
+
+// Dump the output of the tools, helpful if the verbose flag was passed.
 #define DEBUG_TYPE "serialize-to-binary"
   LLVM_DEBUG({
-    llvm::dbgs() << "PTXAS invocation for module: "
-                 << getOperation().getNameAttr() << "\n";
-    llvm::dbgs() << "Command: ";
-    llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
-    llvm::dbgs() << "\n";
-    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasLog =
+    llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
         llvm::MemoryBuffer::getFile(logFile->first);
-    if (ptxasLog && (*ptxasLog)->getBuffer().size()) {
-      llvm::dbgs() << "Output:\n" << (*ptxasLog)->getBuffer() << "\n";
+    if (logBuffer && (*logBuffer)->getBuffer().size()) {
+      llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n";
       llvm::dbgs().flush();
     }
   });
 #undef DEBUG_TYPE
 
-  // Read the cubin file.
-  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cubinBuffer =
-      llvm::MemoryBuffer::getFile(cubinFile->first);
-  if (!cubinBuffer) {
-    getOperation().emitError()
-        << "Couldn't open the file: `" << cubinFile->first
-        << "`, error message: " << cubinBuffer.getError().message();
+  // Read the output: a fatbin, or a plain cubin when `createFatbin` is false.
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> binaryBuffer =
+      llvm::MemoryBuffer::getFile(binaryFile->first);
+  if (!binaryBuffer) {
+    emitError(loc) << "Couldn't open the file: `" << binaryFile->first
+                   << "`, error message: " << binaryBuffer.getError().message();
     return std::nullopt;
   }
-  StringRef cubinStr = (*cubinBuffer)->getBuffer();
-  return SmallVector<char, 0>(cubinStr.begin(), cubinStr.end());
+  StringRef fatbin = (*binaryBuffer)->getBuffer();
+  return SmallVector<char, 0>(fatbin.begin(), fatbin.end());
 }
 
 #if MLIR_NVPTXCOMPILER_ENABLED == 1


        


More information about the Mlir-commits mailing list