[Mlir-commits] [mlir] 2224221 - [mlir] Add NVVM to CUBIN conversion to mlir-opt

Christian Sigg llvmlistbot at llvm.org
Thu Mar 11 01:07:20 PST 2021


Author: Christian Sigg
Date: 2021-03-11T10:07:11+01:00
New Revision: 2224221fb3fa9738bd84221ed048247089187fce

URL: https://github.com/llvm/llvm-project/commit/2224221fb3fa9738bd84221ed048247089187fce
DIFF: https://github.com/llvm/llvm-project/commit/2224221fb3fa9738bd84221ed048247089187fce.diff

LOG: [mlir] Add NVVM to CUBIN conversion to mlir-opt

If MLIR_CUDA_RUNNER_ENABLED, register a 'gpu-to-cubin' conversion pass to mlir-opt.

The next step is to switch CUDA integration tests from mlir-cuda-runner to mlir-opt + mlir-cpu-runner and remove mlir-cuda-runner.

Depends On D98279

Reviewed By: herhut, rriddle, mehdi_amini

Differential Revision: https://reviews.llvm.org/D98203

Added: 
    mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp

Modified: 
    mlir/include/mlir/Dialect/GPU/Passes.h
    mlir/include/mlir/InitAllPasses.h
    mlir/lib/Conversion/GPUCommon/CMakeLists.txt
    mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
    mlir/lib/Dialect/GPU/CMakeLists.txt
    mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
    mlir/test/Integration/GPU/CUDA/shuffle.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h
index c280026e6de9..6a6a2c0678b6 100644
--- a/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -53,15 +53,18 @@ class SerializeToBlobPass : public OperationPass<gpu::GPUModuleOp> {
 
   void runOnOperation() final;
 
+protected:
+  void getDependentDialects(DialectRegistry &registry) const override;
+
 private:
-  // Creates the LLVM target machine to generate the ISA.
+  /// Creates the LLVM target machine to generate the ISA.
   std::unique_ptr<llvm::TargetMachine> createTargetMachine();
 
-  // Translates the 'getOperation()' result to an LLVM module.
+  /// Translates the 'getOperation()' result to an LLVM module.
   virtual std::unique_ptr<llvm::Module>
-  translateToLLVMIR(llvm::LLVMContext &llvmContext) = 0;
+  translateToLLVMIR(llvm::LLVMContext &llvmContext);
 
-  // Serializes the target ISA to binary form.
+  /// Serializes the target ISA to binary form.
   virtual std::unique_ptr<std::vector<char>>
   serializeISA(const std::string &isa) = 0;
 
@@ -83,6 +86,10 @@ class SerializeToBlobPass : public OperationPass<gpu::GPUModuleOp> {
 // Registration
 //===----------------------------------------------------------------------===//
 
+/// Register pass to serialize GPU kernel functions to a CUBIN binary
+/// annotation. The definition is an empty function unless it is compiled
+/// with MLIR_GPU_TO_CUBIN_PASS_ENABLE defined.
+void registerGpuSerializeToCubinPass();
+
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/GPU/Passes.h.inc"

diff  --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index a3de7e345c70..029df0735959 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -51,6 +51,7 @@ inline void registerAllPasses() {
   registerAffinePasses();
   registerAsyncPasses();
   registerGPUPasses();
+  registerGpuSerializeToCubinPass();
   registerLinalgPasses();
   LLVM::registerLLVMPasses();
   quant::registerQuantPasses();

diff  --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
index 53da5e00233a..825bed600aba 100644
--- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
@@ -24,6 +24,8 @@ add_mlir_conversion_library(MLIRGPUToGPURuntimeTransforms
   intrinsics_gen
 
   LINK_COMPONENTS
+  Core
+  MC
   ${AMDGPU_LIBS}
   ${NVPTX_LIBS}
 

diff  --git a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
index 2f57524d8425..e8f9a7a46936 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
@@ -61,6 +61,8 @@ class GpuKernelToBlobPass
 
 private:
   // Translates the 'getOperation()' result to an LLVM module.
+  // Note: when this class is removed, this function no longer needs to be
+  // virtual.
   std::unique_ptr<llvm::Module>
   translateToLLVMIR(llvm::LLVMContext &llvmContext) override {
     return loweringCallback(getOperation(), llvmContext, "LLVMDialectModule");

diff  --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index ed0113800623..4d26ba400f18 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -1,3 +1,11 @@
+# LLVM NVPTX backend components. They are added to the LINK_COMPONENTS of
+# MLIRGPU below when the CUDA conversions are enabled.
+if (MLIR_CUDA_CONVERSIONS_ENABLED)
+  set(NVPTX_LIBS
+    NVPTXCodeGen
+    NVPTXDesc
+    NVPTXInfo
+  )
+endif()
+
 add_mlir_dialect_library(MLIRGPU
   IR/GPUDialect.cpp
   Transforms/AllReduceLowering.cpp
@@ -6,6 +14,7 @@ add_mlir_dialect_library(MLIRGPU
   Transforms/MemoryPromotion.cpp
   Transforms/ParallelLoopMapper.cpp
   Transforms/SerializeToBlob.cpp
+  Transforms/SerializeToCubin.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
@@ -13,6 +22,7 @@ add_mlir_dialect_library(MLIRGPU
   LINK_COMPONENTS
   Core
   MC
+  ${NVPTX_LIBS}
 
   DEPENDS
   MLIRGPUOpsIncGen
@@ -26,6 +36,7 @@ add_mlir_dialect_library(MLIRGPU
   MLIREDSC
   MLIRIR
   MLIRLLVMIR
+  MLIRLLVMToLLVMIRTranslation
   MLIRSCF
   MLIRPass
   MLIRSideEffectInterfaces
@@ -33,3 +44,42 @@ add_mlir_dialect_library(MLIRGPU
   MLIRSupport
   MLIRTransformUtils
   )
+
+# Optional CUDA support: when MLIR_CUDA_RUNNER_ENABLED is set, build the
+# gpu-to-cubin pass into MLIRGPU and link it against the CUDA driver library.
+if(MLIR_CUDA_RUNNER_ENABLED)
+  # The pass initializes the NVPTX backend, whose libraries are only added to
+  # the link when MLIR_CUDA_CONVERSIONS_ENABLED is set.
+  if(NOT MLIR_CUDA_CONVERSIONS_ENABLED)
+    message(SEND_ERROR
+      "Building mlir with cuda support requires the NVPTX backend")
+  endif()
+
+  # Configure CUDA language support. Using check_language first allows us to
+  # give a custom error message.
+  include(CheckLanguage)
+  check_language(CUDA)
+  if (CMAKE_CUDA_COMPILER)
+    enable_language(CUDA)
+  else()
+    message(SEND_ERROR
+      "Building mlir with cuda support requires a working CUDA install")
+  endif()
+
+  # Enable gpu-to-cubin pass. Without this definition SerializeToCubin.cpp
+  # only provides an empty registration function.
+  target_compile_definitions(obj.MLIRGPU
+    PRIVATE
+    MLIR_GPU_TO_CUBIN_PASS_ENABLE=1
+  )
+
+  # Add CUDA headers includes and the libcuda.so library.
+  target_include_directories(obj.MLIRGPU
+    PRIVATE
+    ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
+  )
+
+  # NOTE(review): the result of find_library is not checked here; a missing
+  # driver library is only reported when the variable is used below.
+  find_library(CUDA_DRIVER_LIBRARY cuda)
+
+  target_link_libraries(MLIRGPU
+    PRIVATE
+    MLIRNVVMToLLVMIRTranslation
+    ${CUDA_DRIVER_LIBRARY}
+  )
+
+endif()

diff  --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
index 06c2c508d5ad..7abb34773efc 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
@@ -14,6 +14,8 @@
 
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/Support/TargetRegistry.h"
 #include "llvm/Support/TargetSelect.h"
@@ -68,6 +70,12 @@ void gpu::SerializeToBlobPass::runOnOperation() {
   getOperation()->setAttr(gpuBinaryAnnotation, attr);
 }
 
+/// Adds the LLVM dialect translation to `registry`, then forwards to the
+/// OperationPass base implementation.
+void gpu::SerializeToBlobPass::getDependentDialects(
+    DialectRegistry &registry) const {
+  registerLLVMDialectTranslation(registry);
+  OperationPass<gpu::GPUModuleOp>::getDependentDialects(registry);
+}
+
 std::unique_ptr<llvm::TargetMachine>
 gpu::SerializeToBlobPass::createTargetMachine() {
   Location loc = getOperation().getLoc();
@@ -87,3 +95,9 @@ gpu::SerializeToBlobPass::createTargetMachine() {
 
   return std::unique_ptr<llvm::TargetMachine>{machine};
 }
+
+/// Default translation of the 'getOperation()' result: hands the GPU module
+/// to translateModuleToLLVMIR together with `llvmContext` and the module name
+/// "LLVMDialectModule".
+std::unique_ptr<llvm::Module>
+gpu::SerializeToBlobPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
+  gpu::GPUModuleOp gpuModule = getOperation();
+  return translateModuleToLLVMIR(gpuModule, llvmContext, "LLVMDialectModule");
+}

diff  --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp
new file mode 100644
index 000000000000..a79a8d672775
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp
@@ -0,0 +1,142 @@
+//===- SerializeToCubin.cpp - Convert GPU kernel to CUBIN blob ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that serializes a gpu module into CUBIN blob and
+// adds that blob as a string attribute of the module.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/GPU/Passes.h"
+
+#if MLIR_GPU_TO_CUBIN_PASS_ENABLE
+#include "mlir/Pass/Pass.h"
+#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/Support/TargetSelect.h"
+
+#include <cuda.h>
+#include <memory>
+#include <type_traits>
+
+using namespace mlir;
+
+/// Emits an MLIR error at `loc` reporting that the CUDA driver call spelled
+/// `expr` failed with `result`. The driver's description of `result` and the
+/// contents of `buffer` (the JIT error log) are appended to the message.
+static void emitCudaError(const llvm::Twine &expr, const char *buffer,
+                          CUresult result, Location loc) {
+  // cuGetErrorString() sets the string to NULL for unrecognized error codes,
+  // and building a Twine from a null pointer would crash, so fall back to a
+  // placeholder.
+  const char *error = nullptr;
+  if (cuGetErrorString(result, &error) != CUDA_SUCCESS || !error)
+    error = "<unknown>";
+  emitError(loc, expr.concat(" failed with error code ")
+                     .concat(llvm::Twine{error})
+                     .concat("[")
+                     .concat(buffer)
+                     .concat("]"));
+}
+
+// Evaluates the CUDA driver call 'expr'. If it yields a non-zero status, emits
+// an error through emitCudaError() and returns a value-initialized result from
+// the enclosing function. Variables named 'jitErrorBuffer' and 'loc' must be in
+// scope wherever the macro is expanded.
+#define RETURN_ON_CUDA_ERROR(expr)                                             \
+  do {                                                                         \
+    if (auto status = (expr)) {                                                \
+      emitCudaError(#expr, jitErrorBuffer, status, loc);                       \
+      return {};                                                               \
+    }                                                                          \
+  } while (false)
+
+namespace {
+/// Pass that serializes a GPU module to a CUBIN blob. It builds on
+/// gpu::SerializeToBlobPass, supplying default NVPTX target options and the
+/// PTX-to-CUBIN step.
+class SerializeToCubinPass
+    : public PassWrapper<SerializeToCubinPass, gpu::SerializeToBlobPass> {
+public:
+  /// Fills in the triple, chip and features options where they have no value.
+  SerializeToCubinPass();
+
+private:
+  /// Registers the NVVM dialect translation in addition to the dependencies
+  /// of the base class.
+  void getDependentDialects(DialectRegistry &registry) const override;
+
+  /// Serializes PTX to CUBIN.
+  std::unique_ptr<std::vector<char>>
+  serializeISA(const std::string &isa) override;
+};
+} // namespace
+
+// Assigns 'value' to 'option' only if 'option' does not hold a value yet; an
+// existing value is left untouched.
+static void maybeSetOption(Pass::Option<std::string> &option,
+                           const char *value) {
+  if (option.hasValue())
+    return;
+  option = value;
+}
+
+// Supplies the NVPTX defaults for the target options; options that already
+// have a value keep it.
+SerializeToCubinPass::SerializeToCubinPass() {
+  maybeSetOption(triple, "nvptx64-nvidia-cuda");
+  maybeSetOption(chip, "sm_35");
+  maybeSetOption(features, "+ptx60");
+}
+
+// Adds the NVVM dialect translation to 'registry', then forwards to the
+// gpu::SerializeToBlobPass implementation.
+void SerializeToCubinPass::getDependentDialects(
+    DialectRegistry &registry) const {
+  registerNVVMDialectTranslation(registry);
+  gpu::SerializeToBlobPass::getDependentDialects(registry);
+}
+
+/// Compiles the PTX text in `isa` to a CUBIN image with the CUDA driver's JIT
+/// linker. Returns the image bytes, or null after emitting an error at the
+/// module's location if any driver call fails.
+std::unique_ptr<std::vector<char>>
+SerializeToCubinPass::serializeISA(const std::string &isa) {
+  Location loc = getOperation().getLoc();
+  // Receives the JIT error log; RETURN_ON_CUDA_ERROR appends it to the
+  // diagnostic.
+  char jitErrorBuffer[4096] = {0};
+
+  RETURN_ON_CUDA_ERROR(cuInit(0));
+
+  // Linking requires a device context.
+  CUdevice device;
+  RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0));
+  CUcontext context;
+  RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device));
+  // Destroy the context on the early-return error paths below, which would
+  // otherwise leak it.
+  std::unique_ptr<std::remove_pointer_t<CUcontext>, decltype(&cuCtxDestroy)>
+      contextGuard(context, &cuCtxDestroy);
+
+  CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
+                               CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
+  void *jitOptionsVals[] = {jitErrorBuffer,
+                            reinterpret_cast<void *>(sizeof(jitErrorBuffer))};
+
+  CUlinkState linkState;
+  RETURN_ON_CUDA_ERROR(cuLinkCreate(2,              /* number of jit options */
+                                    jitOptions,     /* jit options */
+                                    jitOptionsVals, /* jit option values */
+                                    &linkState));
+  // Same for the linker state. It is declared after 'contextGuard' so that it
+  // is destroyed before the context.
+  std::unique_ptr<std::remove_pointer_t<CUlinkState>, decltype(&cuLinkDestroy)>
+      linkGuard(linkState, &cuLinkDestroy);
+
+  auto kernelName = getOperation().getName().str();
+  RETURN_ON_CUDA_ERROR(cuLinkAddData(
+      linkState, CUjitInputType::CU_JIT_INPUT_PTX,
+      const_cast<void *>(static_cast<const void *>(isa.c_str())), isa.length(),
+      kernelName.c_str(), 0, /* number of jit options */
+      nullptr,               /* jit options */
+      nullptr                /* jit option values */
+      ));
+
+  void *cubinData;
+  size_t cubinSize;
+  RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize));
+
+  // Copy the image out before the linker state that owns it is destroyed.
+  char *cubinAsChar = static_cast<char *>(cubinData);
+  auto result =
+      std::make_unique<std::vector<char>>(cubinAsChar, cubinAsChar + cubinSize);
+
+  // This will also destroy the cubin data. Each guard is released right before
+  // the explicit destroy call so that every handle is destroyed exactly once
+  // and a failing destroy is still reported.
+  (void)linkGuard.release();
+  RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState));
+  (void)contextGuard.release();
+  RETURN_ON_CUDA_ERROR(cuCtxDestroy(context));
+
+  return result;
+}
+
+// Registers the 'gpu-to-cubin' pass, which serializes GPU kernel functions to
+// a CUBIN binary annotation.
+void mlir::registerGpuSerializeToCubinPass() {
+  // Factory handed to the registration. It initializes the LLVM NVPTX backend
+  // before constructing a pass instance.
+  auto createPass = [] {
+    LLVMInitializeNVPTXTarget();
+    LLVMInitializeNVPTXTargetInfo();
+    LLVMInitializeNVPTXTargetMC();
+    LLVMInitializeNVPTXAsmPrinter();
+
+    return std::make_unique<SerializeToCubinPass>();
+  };
+
+  PassRegistration<SerializeToCubinPass> registerSerializeToCubin(
+      "gpu-to-cubin", "Lower GPU kernel function to CUBIN binary annotations",
+      createPass);
+}
+#else  // MLIR_GPU_TO_CUBIN_PASS_ENABLE
+// CUBIN serialization is compiled out, so registration does nothing.
+void mlir::registerGpuSerializeToCubinPass() {}
+#endif // MLIR_GPU_TO_CUBIN_PASS_ENABLE

diff  --git a/mlir/test/Integration/GPU/CUDA/shuffle.mlir b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
index e81bc696fdfb..1c1075debbef 100644
--- a/mlir/test/Integration/GPU/CUDA/shuffle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
@@ -1,6 +1,8 @@
-// RUN: mlir-cuda-runner %s \
-// RUN:   -gpu-to-cubin="gpu-binary-annotation=nvvm.cubin" \
+// RUN: mlir-opt %s \
+// RUN:   -gpu-kernel-outlining \
+// RUN:   -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin{gpu-binary-annotation=nvvm.cubin})' \
 // RUN:   -gpu-to-llvm="gpu-binary-annotation=nvvm.cubin" \
+// RUN: | mlir-cpu-runner \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
 // RUN:   --shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
 // RUN:   --entry-point-result=void \


        


More information about the Mlir-commits mailing list