[Mlir-commits] [mlir] 2224221 - [mlir] Add NVVM to CUBIN conversion to mlir-opt
Christian Sigg
llvmlistbot at llvm.org
Thu Mar 11 01:07:20 PST 2021
Author: Christian Sigg
Date: 2021-03-11T10:07:11+01:00
New Revision: 2224221fb3fa9738bd84221ed048247089187fce
URL: https://github.com/llvm/llvm-project/commit/2224221fb3fa9738bd84221ed048247089187fce
DIFF: https://github.com/llvm/llvm-project/commit/2224221fb3fa9738bd84221ed048247089187fce.diff
LOG: [mlir] Add NVVM to CUBIN conversion to mlir-opt
If MLIR_CUDA_RUNNER_ENABLED, register a 'gpu-to-cubin' conversion pass to mlir-opt.
The next step is to switch CUDA integration tests from mlir-cuda-runner to mlir-opt + mlir-cpu-runner and remove mlir-cuda-runner.
Depends On D98279
Reviewed By: herhut, rriddle, mehdi_amini
Differential Revision: https://reviews.llvm.org/D98203
Added:
mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp
Modified:
mlir/include/mlir/Dialect/GPU/Passes.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/GPUCommon/CMakeLists.txt
mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
mlir/lib/Dialect/GPU/CMakeLists.txt
mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
mlir/test/Integration/GPU/CUDA/shuffle.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h
index c280026e6de9..6a6a2c0678b6 100644
--- a/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -53,15 +53,18 @@ class SerializeToBlobPass : public OperationPass<gpu::GPUModuleOp> {
void runOnOperation() final;
+protected:
+ void getDependentDialects(DialectRegistry &registry) const override;
+
private:
- // Creates the LLVM target machine to generate the ISA.
+ /// Creates the LLVM target machine to generate the ISA.
std::unique_ptr<llvm::TargetMachine> createTargetMachine();
- // Translates the 'getOperation()' result to an LLVM module.
+ /// Translates the 'getOperation()' result to an LLVM module.
virtual std::unique_ptr<llvm::Module>
- translateToLLVMIR(llvm::LLVMContext &llvmContext) = 0;
+ translateToLLVMIR(llvm::LLVMContext &llvmContext);
- // Serializes the target ISA to binary form.
+ /// Serializes the target ISA to binary form.
virtual std::unique_ptr<std::vector<char>>
serializeISA(const std::string &isa) = 0;
@@ -83,6 +86,10 @@ class SerializeToBlobPass : public OperationPass<gpu::GPUModuleOp> {
// Registration
//===----------------------------------------------------------------------===//
+/// Register pass to serialize GPU kernel functions to a CUBIN binary
+/// annotation.
+void registerGpuSerializeToCubinPass();
+
/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/GPU/Passes.h.inc"
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index a3de7e345c70..029df0735959 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -51,6 +51,7 @@ inline void registerAllPasses() {
registerAffinePasses();
registerAsyncPasses();
registerGPUPasses();
+ registerGpuSerializeToCubinPass();
registerLinalgPasses();
LLVM::registerLLVMPasses();
quant::registerQuantPasses();
diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
index 53da5e00233a..825bed600aba 100644
--- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
@@ -24,6 +24,8 @@ add_mlir_conversion_library(MLIRGPUToGPURuntimeTransforms
intrinsics_gen
LINK_COMPONENTS
+ Core
+ MC
${AMDGPU_LIBS}
${NVPTX_LIBS}
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
index 2f57524d8425..e8f9a7a46936 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
@@ -61,6 +61,8 @@ class GpuKernelToBlobPass
private:
// Translates the 'getOperation()' result to an LLVM module.
+ // Note: when this class is removed, this function no longer needs to be
+ // virtual.
std::unique_ptr<llvm::Module>
translateToLLVMIR(llvm::LLVMContext &llvmContext) override {
return loweringCallback(getOperation(), llvmContext, "LLVMDialectModule");
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index ed0113800623..4d26ba400f18 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -1,3 +1,11 @@
+if (MLIR_CUDA_CONVERSIONS_ENABLED)
+ set(NVPTX_LIBS
+ NVPTXCodeGen
+ NVPTXDesc
+ NVPTXInfo
+ )
+endif()
+
add_mlir_dialect_library(MLIRGPU
IR/GPUDialect.cpp
Transforms/AllReduceLowering.cpp
@@ -6,6 +14,7 @@ add_mlir_dialect_library(MLIRGPU
Transforms/MemoryPromotion.cpp
Transforms/ParallelLoopMapper.cpp
Transforms/SerializeToBlob.cpp
+ Transforms/SerializeToCubin.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
@@ -13,6 +22,7 @@ add_mlir_dialect_library(MLIRGPU
LINK_COMPONENTS
Core
MC
+ ${NVPTX_LIBS}
DEPENDS
MLIRGPUOpsIncGen
@@ -26,6 +36,7 @@ add_mlir_dialect_library(MLIRGPU
MLIREDSC
MLIRIR
MLIRLLVMIR
+ MLIRLLVMToLLVMIRTranslation
MLIRSCF
MLIRPass
MLIRSideEffectInterfaces
@@ -33,3 +44,42 @@ add_mlir_dialect_library(MLIRGPU
MLIRSupport
MLIRTransformUtils
)
+
+if(MLIR_CUDA_RUNNER_ENABLED)
+ if(NOT MLIR_CUDA_CONVERSIONS_ENABLED)
+ message(SEND_ERROR
+ "Building mlir with cuda support requires the NVPTX backend")
+ endif()
+
+ # Configure CUDA language support. Using check_language first allows us to
+ # give a custom error message.
+ include(CheckLanguage)
+ check_language(CUDA)
+ if (CMAKE_CUDA_COMPILER)
+ enable_language(CUDA)
+ else()
+ message(SEND_ERROR
+ "Building mlir with cuda support requires a working CUDA install")
+ endif()
+
+ # Enable gpu-to-cubin pass.
+ target_compile_definitions(obj.MLIRGPU
+ PRIVATE
+ MLIR_GPU_TO_CUBIN_PASS_ENABLE=1
+ )
+
+ # Add CUDA headers includes and the libcuda.so library.
+ target_include_directories(obj.MLIRGPU
+ PRIVATE
+ ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
+ )
+
+ find_library(CUDA_DRIVER_LIBRARY cuda)
+
+ target_link_libraries(MLIRGPU
+ PRIVATE
+ MLIRNVVMToLLVMIRTranslation
+ ${CUDA_DRIVER_LIBRARY}
+ )
+
+endif()
diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
index 06c2c508d5ad..7abb34773efc 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp
@@ -14,6 +14,8 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/TargetSelect.h"
@@ -68,6 +70,12 @@ void gpu::SerializeToBlobPass::runOnOperation() {
getOperation()->setAttr(gpuBinaryAnnotation, attr);
}
+void gpu::SerializeToBlobPass::getDependentDialects(
+ DialectRegistry &registry) const {
+ registerLLVMDialectTranslation(registry);
+ OperationPass<gpu::GPUModuleOp>::getDependentDialects(registry);
+}
+
std::unique_ptr<llvm::TargetMachine>
gpu::SerializeToBlobPass::createTargetMachine() {
Location loc = getOperation().getLoc();
@@ -87,3 +95,9 @@ gpu::SerializeToBlobPass::createTargetMachine() {
return std::unique_ptr<llvm::TargetMachine>{machine};
}
+
+std::unique_ptr<llvm::Module>
+gpu::SerializeToBlobPass::translateToLLVMIR(llvm::LLVMContext &llvmContext) {
+ return translateModuleToLLVMIR(getOperation(), llvmContext,
+ "LLVMDialectModule");
+}
diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp
new file mode 100644
index 000000000000..a79a8d672775
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToCubin.cpp
@@ -0,0 +1,142 @@
+//===- LowerGPUToCUBIN.cpp - Convert GPU kernel to CUBIN blob -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that serializes a gpu module into CUBIN blob and
+// adds that blob as a string attribute of the module.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/GPU/Passes.h"
+
+#if MLIR_GPU_TO_CUBIN_PASS_ENABLE
+#include "mlir/Pass/Pass.h"
+#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/Support/TargetSelect.h"
+
+#include <cuda.h>
+
+using namespace mlir;
+
+static void emitCudaError(const llvm::Twine &expr, const char *buffer,
+ CUresult result, Location loc) {
+ const char *error;
+ cuGetErrorString(result, &error);
+ emitError(loc, expr.concat(" failed with error code ")
+ .concat(llvm::Twine{error})
+ .concat("[")
+ .concat(buffer)
+ .concat("]"));
+}
+
+#define RETURN_ON_CUDA_ERROR(expr) \
+ do { \
+ if (auto status = (expr)) { \
+ emitCudaError(#expr, jitErrorBuffer, status, loc); \
+ return {}; \
+ } \
+ } while (false)
+
+namespace {
+class SerializeToCubinPass
+ : public PassWrapper<SerializeToCubinPass, gpu::SerializeToBlobPass> {
+public:
+ SerializeToCubinPass();
+
+private:
+ void getDependentDialects(DialectRegistry &registry) const override;
+
+ // Serializes PTX to CUBIN.
+ std::unique_ptr<std::vector<char>>
+ serializeISA(const std::string &isa) override;
+};
+} // namespace
+
+// Sets the 'option' to 'value' unless it already has a value.
+static void maybeSetOption(Pass::Option<std::string> &option,
+ const char *value) {
+ if (!option.hasValue())
+ option = value;
+}
+
+SerializeToCubinPass::SerializeToCubinPass() {
+ maybeSetOption(this->triple, "nvptx64-nvidia-cuda");
+ maybeSetOption(this->chip, "sm_35");
+ maybeSetOption(this->features, "+ptx60");
+}
+
+void SerializeToCubinPass::getDependentDialects(
+ DialectRegistry &registry) const {
+ registerNVVMDialectTranslation(registry);
+ gpu::SerializeToBlobPass::getDependentDialects(registry);
+}
+
+std::unique_ptr<std::vector<char>>
+SerializeToCubinPass::serializeISA(const std::string &isa) {
+ Location loc = getOperation().getLoc();
+ char jitErrorBuffer[4096] = {0};
+
+ RETURN_ON_CUDA_ERROR(cuInit(0));
+
+ // Linking requires a device context.
+ CUdevice device;
+ RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0));
+ CUcontext context;
+ RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device));
+ CUlinkState linkState;
+
+ CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
+ CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
+ void *jitOptionsVals[] = {jitErrorBuffer,
+ reinterpret_cast<void *>(sizeof(jitErrorBuffer))};
+
+ RETURN_ON_CUDA_ERROR(cuLinkCreate(2, /* number of jit options */
+ jitOptions, /* jit options */
+ jitOptionsVals, /* jit option values */
+ &linkState));
+
+ auto kernelName = getOperation().getName().str();
+ RETURN_ON_CUDA_ERROR(cuLinkAddData(
+ linkState, CUjitInputType::CU_JIT_INPUT_PTX,
+ const_cast<void *>(static_cast<const void *>(isa.c_str())), isa.length(),
+ kernelName.c_str(), 0, /* number of jit options */
+ nullptr, /* jit options */
+ nullptr /* jit option values */
+ ));
+
+ void *cubinData;
+ size_t cubinSize;
+ RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize));
+
+ char *cubinAsChar = static_cast<char *>(cubinData);
+ auto result =
+ std::make_unique<std::vector<char>>(cubinAsChar, cubinAsChar + cubinSize);
+
+ // This will also destroy the cubin data.
+ RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState));
+ RETURN_ON_CUDA_ERROR(cuCtxDestroy(context));
+
+ return result;
+}
+
+// Register pass to serialize GPU kernel functions to a CUBIN binary annotation.
+void mlir::registerGpuSerializeToCubinPass() {
+ PassRegistration<SerializeToCubinPass> registerSerializeToCubin(
+ "gpu-to-cubin", "Lower GPU kernel function to CUBIN binary annotations",
+ [] {
+ // Initialize LLVM NVPTX backend.
+ LLVMInitializeNVPTXTarget();
+ LLVMInitializeNVPTXTargetInfo();
+ LLVMInitializeNVPTXTargetMC();
+ LLVMInitializeNVPTXAsmPrinter();
+
+ return std::make_unique<SerializeToCubinPass>();
+ });
+}
+#else // MLIR_GPU_TO_CUBIN_PASS_ENABLE
+void mlir::registerGpuSerializeToCubinPass() {}
+#endif // MLIR_GPU_TO_CUBIN_PASS_ENABLE
diff --git a/mlir/test/Integration/GPU/CUDA/shuffle.mlir b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
index e81bc696fdfb..1c1075debbef 100644
--- a/mlir/test/Integration/GPU/CUDA/shuffle.mlir
+++ b/mlir/test/Integration/GPU/CUDA/shuffle.mlir
@@ -1,6 +1,8 @@
-// RUN: mlir-cuda-runner %s \
-// RUN: -gpu-to-cubin="gpu-binary-annotation=nvvm.cubin" \
+// RUN: mlir-opt %s \
+// RUN: -gpu-kernel-outlining \
+// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin{gpu-binary-annotation=nvvm.cubin})' \
// RUN: -gpu-to-llvm="gpu-binary-annotation=nvvm.cubin" \
+// RUN: | mlir-cpu-runner \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_cuda_runtime%shlibext \
// RUN: --shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
// RUN: --entry-point-result=void \
More information about the Mlir-commits
mailing list