[Mlir-commits] [mlir] 061fb8e - [mlir][gpu][mlir-cuda-runner] Refactor ConvertKernelFuncToCubin to be generic.
Wen-Heng Chung
llvmlistbot at llvm.org
Thu May 28 07:09:06 PDT 2020
Author: Wen-Heng (Jack) Chung
Date: 2020-05-28T09:08:28-05:00
New Revision: 061fb8eb2d9f6ffa05f2b57670c918c477ca7f36
URL: https://github.com/llvm/llvm-project/commit/061fb8eb2d9f6ffa05f2b57670c918c477ca7f36
DIFF: https://github.com/llvm/llvm-project/commit/061fb8eb2d9f6ffa05f2b57670c918c477ca7f36.diff
LOG: [mlir][gpu][mlir-cuda-runner] Refactor ConvertKernelFuncToCubin to be generic.
Make the ConvertKernelFuncToCubin pass generic:
- Rename to ConvertKernelFuncToBlob.
- Allow specifying triple, target chip, target features.
- Initializing the LLVM backend is left to the caller (the test passes and
  mlir-cuda-runner now do it themselves); the pass no longer does it.
- Lowering from the MLIR module to an LLVM module is done via a user-supplied
  callback.
- Change mlir-cuda-runner to adopt the revised pass.
- Add new tests for lowering to ROCm HSA code object (HSACO).
- Tests for CUDA and ROCm are kept in separate directories.
Differential Revision: https://reviews.llvm.org/D80142
Added:
mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
mlir/test/Conversion/GPUToROCm/lit.local.cfg
mlir/test/Conversion/GPUToROCm/lower-rocdl-kernel-to-hsaco.mlir
mlir/test/lib/Transforms/TestConvertGPUKernelToHsaco.cpp
Modified:
mlir/CMakeLists.txt
mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
mlir/include/mlir/InitAllPasses.h
mlir/lib/Conversion/CMakeLists.txt
mlir/lib/Conversion/GPUCommon/CMakeLists.txt
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestConvertGPUKernelToCubin.cpp
mlir/test/lit.site.cfg.py.in
mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
################################################################################
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 7c2c5978c44e..0cf1e8d44516 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -31,6 +31,15 @@ endif()
# TODO: we should use a config.h file like LLVM does
add_definitions(-DMLIR_CUDA_CONVERSIONS_ENABLED=${MLIR_CUDA_CONVERSIONS_ENABLED})
+# Build the ROCm conversions and run the corresponding tests if the AMDGPU
+# backend is available.
+if ("AMDGPU" IN_LIST LLVM_TARGETS_TO_BUILD)
+ set(MLIR_ROCM_CONVERSIONS_ENABLED 1)
+else()
+ set(MLIR_ROCM_CONVERSIONS_ENABLED 0)
+endif()
+add_definitions(-DMLIR_ROCM_CONVERSIONS_ENABLED=${MLIR_ROCM_CONVERSIONS_ENABLED})
+
set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
set(MLIR_VULKAN_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir Vulkan runner")
diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
index 791d859f6414..2c4b3dc6ac88 100644
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -9,19 +9,43 @@
#define MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_
#include "mlir/Support/LLVM.h"
+#include "llvm/IR/Module.h"
#include <functional>
#include <memory>
#include <string>
#include <vector>
namespace mlir {
class Location;
+class LogicalResult;
class ModuleOp;
+class Operation;
template <typename T>
class OperationPass;
+namespace gpu {
+class GPUModuleOp;
+} // namespace gpu
+
+namespace LLVM {
+class LLVMDialect;
+} // namespace LLVM
+
+/// Owning handle to a binary blob (e.g. a GPU binary image) produced by a
+/// BlobGenerator.
+using OwnedBlob = std::unique_ptr<std::vector<char>>;
+/// Compiles target ISA text into a binary blob. Receives the ISA text, a
+/// location for error reporting and a name for logging purposes. Returning
+/// null signals failure.
+using BlobGenerator =
+    std::function<OwnedBlob(const std::string &, Location, StringRef)>;
+/// Lowers the given operation (a GPU module) to an LLVM module. Returning null
+/// signals failure.
+using LoweringCallback =
+    std::function<std::unique_ptr<llvm::Module>(Operation *)>;
+
/// Creates a pass to convert a gpu.launch_func operation into a sequence of
/// GPU runtime calls.
///
@@ -31,6 +55,37 @@ class OperationPass;
std::unique_ptr<OperationPass<ModuleOp>>
createConvertGpuLaunchFuncToGpuRuntimeCallsPass();
+/// Creates a pass that converts a GPU module into a GPU target object blob.
+///
+/// This transformation lowers the gpu.module it runs on to an LLVM module with
+/// the provided loweringCallback, compiles that module to target assembly
+/// (ISA) with the LLVM backend selected by `triple`, and then invokes the
+/// provided blobGenerator to produce a binary blob. The blob is attached to
+/// the gpu.module as a string attribute named `gpuBinaryAnnotation`. The
+/// functions contained in the module are not modified.
+///
+/// Callbacks to be provided by the user:
+///   - loweringCallback : lowers the module to an LLVM module.
+///   - blobGenerator    : builds a blob executable on the target GPU from the
+///                        generated ISA.
+/// Either callback returning null makes the pass fail.
+///
+/// LLVM backend information to be supplied by the user:
+///   - triple     : target triple to be used.
+///   - targetChip : target CPU (mcpu) to be used.
+///   - features   : target-specific features to be used.
+/// The LLVM backend for `triple` must have been initialized by the caller
+/// (target, target info, MC layer and asm printer); the pass does not
+/// initialize it.
+///
+/// Information about the result attribute to be specified by the user:
+///   - gpuBinaryAnnotation : the name of the attribute which holds the blob.
+std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
+createConvertGPUKernelToBlobPass(LoweringCallback loweringCallback,
+                                 BlobGenerator blobGenerator, StringRef triple,
+                                 StringRef targetChip, StringRef features,
+                                 StringRef gpuBinaryAnnotation);
+
} // namespace mlir
#endif // MLIR_CONVERSION_GPUCOMMON_GPUCOMMONPASS_H_
diff --git a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h b/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
deleted file mode 100644
index bac13d6d7ccb..000000000000
--- a/mlir/include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h
+++ /dev/null
@@ -1,50 +0,0 @@
-//===- GPUToCUDAPass.h - MLIR CUDA runtime support --------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-#ifndef MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
-#define MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
-
-#include "mlir/Support/LLVM.h"
-#include <functional>
-#include <memory>
-#include <string>
-#include <vector>
-
-namespace mlir {
-
-class Location;
-class ModuleOp;
-
-template <typename T> class OperationPass;
-
-namespace gpu {
-class GPUModuleOp;
-} // namespace gpu
-
-namespace LLVM {
-class LLVMDialect;
-} // namespace LLVM
-
-using OwnedCubin = std::unique_ptr<std::vector<char>>;
-using CubinGenerator =
- std::function<OwnedCubin(const std::string &, Location, StringRef)>;
-
-/// Creates a pass to convert kernel functions into CUBIN blobs.
-///
-/// This transformation takes the body of each function that is annotated with
-/// the 'nvvm.kernel' attribute, copies it to a new LLVM module, compiles the
-/// module with help of the nvptx backend to PTX and then invokes the provided
-/// cubinGenerator to produce a binary blob (the cubin). Such blob is then
-/// attached as a string attribute named 'nvvm.cubin' to the kernel function.
-/// After the transformation, the body of the kernel function is removed (i.e.,
-/// it is turned into a declaration).
-std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator);
-
-} // namespace mlir
-
-#endif // MLIR_CONVERSION_GPUTOCUDA_GPUTOCUDAPASS_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 66083f671cde..fb2ac1ee086f 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -16,7 +16,6 @@
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
-#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 248f5f5a0e6c..8b70e6523106 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -1,7 +1,6 @@
add_subdirectory(AffineToStandard)
add_subdirectory(AVX512ToLLVM)
add_subdirectory(GPUCommon)
-add_subdirectory(GPUToCUDA)
add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToROCDL)
add_subdirectory(GPUToSPIRV)
diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
index a01fb7676b10..eb7d21f66f73 100644
--- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt
@@ -1,9 +1,6 @@
-set(SOURCES
- ConvertLaunchFuncToRuntimeCalls.cpp
-)
-
add_mlir_conversion_library(MLIRGPUtoGPURuntimeTransforms
- ${SOURCES}
+ ConvertLaunchFuncToRuntimeCalls.cpp
+ ConvertKernelFuncToBlob.cpp
DEPENDS
MLIRConversionPassIncGen
diff --git a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
new file mode 100644
index 000000000000..cf41523d3b29
--- /dev/null
+++ b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp
@@ -0,0 +1,188 @@
+//===- ConvertKernelFuncToBlob.cpp - MLIR GPU lowering passes -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert a GPU module (containing kernel
+// functions) into a corresponding binary blob that can be executed on a GPU.
+// Currently only translates the module itself but no dependencies.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/Mutex.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+
+using namespace mlir;
+
+namespace {
+
+/// A pass converting a GPU module to a blob with target instructions.
+///
+/// The whole gpu.module is lowered to an LLVM module by a user-provided
+/// LoweringCallback and compiled to target assembly by the LLVM backend
+/// selected by the configured triple. A user-provided BlobGenerator then turns
+/// the assembly into GPU binary code, which is attached as a string attribute
+/// to the gpu.module. The functions inside the module are left unchanged.
+class GpuKernelToBlobPass
+    : public PassWrapper<GpuKernelToBlobPass, OperationPass<gpu::GPUModuleOp>> {
+public:
+  GpuKernelToBlobPass(LoweringCallback loweringCallback,
+                      BlobGenerator blobGenerator, StringRef triple,
+                      StringRef targetChip, StringRef features,
+                      StringRef gpuBinaryAnnotation)
+      : loweringCallback(std::move(loweringCallback)),
+        blobGenerator(std::move(blobGenerator)), triple(triple),
+        targetChip(targetChip.str()), features(features.str()),
+        blobAnnotation(gpuBinaryAnnotation.str()) {}
+
+  void runOnOperation() override {
+    gpu::GPUModuleOp module = getOperation();
+
+    // Lock access to the llvm context.
+    llvm::sys::SmartScopedLock<true> scopedLock(
+        module.getContext()
+            ->getRegisteredDialect<LLVM::LLVMDialect>()
+            ->getLLVMContextMutex());
+
+    // Lower the module to a llvm module.
+    std::unique_ptr<llvm::Module> llvmModule = loweringCallback(module);
+    if (!llvmModule)
+      return signalPassFailure();
+
+    // Translate the llvm module to a target blob and attach the result as
+    // attribute to the module.
+    if (auto blobAttr = translateGPUModuleToBinaryAnnotation(
+            *llvmModule, module.getLoc(), module.getName()))
+      module.setAttr(blobAnnotation, blobAttr);
+    else
+      signalPassFailure();
+  }
+
+private:
+  /// Runs the code generator of targetMachine on a clone of module and returns
+  /// the emitted target assembly, or None if the target cannot emit assembly.
+  llvm::Optional<std::string>
+  translateModuleToISA(llvm::Module &module,
+                       llvm::TargetMachine &targetMachine);
+
+  /// Converts llvmModule to a blob with target instructions using the
+  /// user-provided generator. Location is used for error reporting and name is
+  /// forwarded to the blob generator to use in its logging mechanisms.
+  /// Returns null on failure.
+  OwnedBlob convertModuleToBlob(llvm::Module &llvmModule, Location loc,
+                                StringRef name);
+
+  /// Translates llvmModule to a blob with target instructions and returns the
+  /// result as attribute, or a null attribute on failure.
+  StringAttr translateGPUModuleToBinaryAnnotation(llvm::Module &llvmModule,
+                                                  Location loc, StringRef name);
+
+  LoweringCallback loweringCallback;
+  BlobGenerator blobGenerator;
+  llvm::Triple triple;
+  // The string options are owned by the pass: callers may pass StringRefs to
+  // temporaries that do not outlive the pass.
+  std::string targetChip;
+  std::string features;
+  std::string blobAnnotation;
+};
+
+} // anonymous namespace
+
+llvm::Optional<std::string>
+GpuKernelToBlobPass::translateModuleToISA(llvm::Module &module,
+                                          llvm::TargetMachine &targetMachine) {
+  std::string targetISA;
+  {
+    // Clone the llvm module into a new context to enable concurrent compilation
+    // with multiple threads.
+    llvm::LLVMContext llvmContext;
+    auto clone = LLVM::cloneModuleIntoNewContext(&llvmContext, &module);
+
+    llvm::raw_string_ostream stream(targetISA);
+    llvm::buffer_ostream pstream(stream);
+    llvm::legacy::PassManager codegenPasses;
+    // addPassesToEmitFile returns true when the target does not support
+    // emitting the requested file type.
+    if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr,
+                                          llvm::CGFT_AssemblyFile))
+      return llvm::None;
+    codegenPasses.run(*clone);
+  }
+
+  return targetISA;
+}
+
+OwnedBlob GpuKernelToBlobPass::convertModuleToBlob(llvm::Module &llvmModule,
+                                                   Location loc,
+                                                   StringRef name) {
+  std::unique_ptr<llvm::TargetMachine> targetMachine;
+  {
+    std::string error;
+    const llvm::Target *target =
+        llvm::TargetRegistry::lookupTarget("", triple, error);
+    if (target == nullptr) {
+      emitError(loc, "cannot initialize target triple: " + error);
+      return {};
+    }
+    targetMachine.reset(target->createTargetMachine(triple.str(), targetChip,
+                                                    features, {}, {}));
+    if (!targetMachine) {
+      emitError(loc, "cannot create target machine for " + triple.str());
+      return {};
+    }
+  }
+
+  llvmModule.setDataLayout(targetMachine->createDataLayout());
+
+  llvm::Optional<std::string> targetISA =
+      translateModuleToISA(llvmModule, *targetMachine);
+  if (!targetISA) {
+    emitError(loc, "cannot emit target assembly for " + triple.str());
+    return {};
+  }
+
+  return blobGenerator(*targetISA, loc, name);
+}
+
+StringAttr GpuKernelToBlobPass::translateGPUModuleToBinaryAnnotation(
+    llvm::Module &llvmModule, Location loc, StringRef name) {
+  auto blob = convertModuleToBlob(llvmModule, loc, name);
+  if (!blob)
+    return {};
+  return StringAttr::get({blob->data(), blob->size()}, loc->getContext());
+}
+
+std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
+mlir::createConvertGPUKernelToBlobPass(LoweringCallback loweringCallback,
+                                       BlobGenerator blobGenerator,
+                                       StringRef triple, StringRef targetChip,
+                                       StringRef features,
+                                       StringRef gpuBinaryAnnotation) {
+  return std::make_unique<GpuKernelToBlobPass>(
+      std::move(loweringCallback), std::move(blobGenerator), triple, targetChip,
+      features, gpuBinaryAnnotation);
+}
diff --git a/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt b/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
deleted file mode 100644
index 90cc8d573ff3..000000000000
--- a/mlir/lib/Conversion/GPUToCUDA/CMakeLists.txt
+++ /dev/null
@@ -1,35 +0,0 @@
-set(LLVM_OPTIONAL_SOURCES
- ConvertKernelFuncToCubin.cpp
-)
-
-if (MLIR_CUDA_CONVERSIONS_ENABLED)
- set(NVPTX_LIBS
- MC
- NVPTXCodeGen
- NVPTXDesc
- NVPTXInfo
- )
-
- add_mlir_conversion_library(MLIRGPUtoCUDATransforms
- ConvertKernelFuncToCubin.cpp
-
- DEPENDS
- MLIRConversionPassIncGen
- intrinsics_gen
-
- LINK_COMPONENTS
- Core
- ${NVPTX_LIBS}
-
- LINK_LIBS PUBLIC
- MLIRGPU
- MLIRIR
- MLIRLLVMIR
- MLIRNVVMIR
- MLIRPass
- MLIRSupport
- MLIRTargetNVVMIR
- )
-else()
- add_library(MLIRGPUtoCUDATransforms INTERFACE IMPORTED GLOBAL)
-endif()
diff --git a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp b/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
deleted file mode 100644
index 3f99c56c4716..000000000000
--- a/mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
+++ /dev/null
@@ -1,165 +0,0 @@
-//===- ConvertKernelFuncToCubin.cpp - MLIR GPU lowering passes ------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements a pass to convert gpu kernel functions into a
-// corresponding binary blob that can be executed on a CUDA GPU. Currently
-// only translates the function itself but no dependencies.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
-
-#include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Module.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Pass/PassRegistry.h"
-#include "mlir/Support/LogicalResult.h"
-#include "mlir/Target/NVVMIR.h"
-
-#include "llvm/ADT/Optional.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/LegacyPassManager.h"
-#include "llvm/IR/Module.h"
-#include "llvm/Support/Error.h"
-#include "llvm/Support/Mutex.h"
-#include "llvm/Support/TargetRegistry.h"
-#include "llvm/Support/TargetSelect.h"
-#include "llvm/Target/TargetMachine.h"
-
-using namespace mlir;
-
-namespace {
-// TODO(herhut): Move to shared location.
-static constexpr const char *kCubinAnnotation = "nvvm.cubin";
-
-/// A pass converting tagged kernel modules to cubin blobs.
-///
-/// If tagged as a kernel module, each contained function is translated to NVVM
-/// IR and further to PTX. A user provided CubinGenerator compiles the PTX to
-/// GPU binary code, which is then attached as an attribute to the function. The
-/// function body is erased.
-class GpuKernelToCubinPass
- : public PassWrapper<GpuKernelToCubinPass,
- OperationPass<gpu::GPUModuleOp>> {
-public:
- GpuKernelToCubinPass(CubinGenerator cubinGenerator)
- : cubinGenerator(cubinGenerator) {}
-
- void runOnOperation() override {
- gpu::GPUModuleOp module = getOperation();
-
- // Lock access to the llvm context.
- llvm::sys::SmartScopedLock<true> scopedLock(
- module.getContext()
- ->getRegisteredDialect<LLVM::LLVMDialect>()
- ->getLLVMContextMutex());
-
- // Make sure the NVPTX target is initialized.
- LLVMInitializeNVPTXTarget();
- LLVMInitializeNVPTXTargetInfo();
- LLVMInitializeNVPTXTargetMC();
- LLVMInitializeNVPTXAsmPrinter();
-
- auto llvmModule = translateModuleToNVVMIR(module);
- if (!llvmModule)
- return signalPassFailure();
-
- // Translate the module to CUBIN and attach the result as attribute to the
- // module.
- if (auto cubinAttr = translateGPUModuleToCubinAnnotation(
- *llvmModule, module.getLoc(), module.getName()))
- module.setAttr(kCubinAnnotation, cubinAttr);
- else
- signalPassFailure();
- }
-
-private:
- std::string translateModuleToPtx(llvm::Module &module,
- llvm::TargetMachine &target_machine);
-
- /// Converts llvmModule to cubin using the user-provided generator. Location
- /// is used for error reporting and name is forwarded to the CUBIN generator
- /// to use in its logging mechanisms.
- OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, Location loc,
- StringRef name);
-
- /// Translates llvmModule to cubin and returns the result as attribute.
- StringAttr translateGPUModuleToCubinAnnotation(llvm::Module &llvmModule,
- Location loc, StringRef name);
-
- CubinGenerator cubinGenerator;
-};
-
-} // anonymous namespace
-
-std::string GpuKernelToCubinPass::translateModuleToPtx(
- llvm::Module &module, llvm::TargetMachine &target_machine) {
- std::string ptx;
- {
- // Clone the llvm module into a new context to enable concurrent compilation
- // with multiple threads.
- // TODO(zinenko): Reevaluate model of ownership of LLVMContext in
- // LLVMDialect.
- llvm::LLVMContext llvmContext;
- auto clone = LLVM::cloneModuleIntoNewContext(&llvmContext, &module);
-
- llvm::raw_string_ostream stream(ptx);
- llvm::buffer_ostream pstream(stream);
- llvm::legacy::PassManager codegen_passes;
- target_machine.addPassesToEmitFile(codegen_passes, pstream, nullptr,
- llvm::CGFT_AssemblyFile);
- codegen_passes.run(*clone);
- }
-
- return ptx;
-}
-
-OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule,
- Location loc,
- StringRef name) {
- std::unique_ptr<llvm::TargetMachine> targetMachine;
- {
- std::string error;
- // TODO(herhut): Make triple configurable.
- constexpr const char *cudaTriple = "nvptx64-nvidia-cuda";
- llvm::Triple triple(cudaTriple);
- const llvm::Target *target =
- llvm::TargetRegistry::lookupTarget("", triple, error);
- if (target == nullptr) {
- emitError(loc, "cannot initialize target triple");
- return {};
- }
- targetMachine.reset(
- target->createTargetMachine(triple.str(), "sm_35", "+ptx60", {}, {}));
- }
-
- // Set the data layout of the llvm module to match what the ptx target needs.
- llvmModule.setDataLayout(targetMachine->createDataLayout());
-
- auto ptx = translateModuleToPtx(llvmModule, *targetMachine);
-
- return cubinGenerator(ptx, loc, name);
-}
-
-StringAttr GpuKernelToCubinPass::translateGPUModuleToCubinAnnotation(
- llvm::Module &llvmModule, Location loc, StringRef name) {
- auto cubin = convertModuleToCubin(llvmModule, loc, name);
- if (!cubin)
- return {};
- return StringAttr::get({cubin->data(), cubin->size()}, loc->getContext());
-}
-
-std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createConvertGPUKernelToCubinPass(CubinGenerator cubinGenerator) {
- return std::make_unique<GpuKernelToCubinPass>(cubinGenerator);
-}
diff --git a/mlir/test/Conversion/GPUToROCm/lit.local.cfg b/mlir/test/Conversion/GPUToROCm/lit.local.cfg
new file mode 100644
index 000000000000..6eb561783b3f
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCm/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.run_rocm_tests:
+ config.unsupported = True
diff --git a/mlir/test/Conversion/GPUToROCm/lower-rocdl-kernel-to-hsaco.mlir b/mlir/test/Conversion/GPUToROCm/lower-rocdl-kernel-to-hsaco.mlir
new file mode 100644
index 000000000000..5ee3bb21aa91
--- /dev/null
+++ b/mlir/test/Conversion/GPUToROCm/lower-rocdl-kernel-to-hsaco.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --test-kernel-to-hsaco -split-input-file | FileCheck %s
+
+// CHECK: attributes {rocdl.hsaco = "HSACO"}
+gpu.module @foo {
+ llvm.func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
+ // CHECK: attributes {gpu.kernel}
+ attributes { gpu.kernel } {
+ llvm.return
+ }
+}
+
+// -----
+
+gpu.module @bar {
+ // CHECK: func @kernel_a
+ llvm.func @kernel_a()
+ attributes { gpu.kernel } {
+ llvm.return
+ }
+
+ // CHECK: func @kernel_b
+ llvm.func @kernel_b()
+ attributes { gpu.kernel } {
+ llvm.return
+ }
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
index d040cdf97abb..55bf84cb1637 100644
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,3 +1,21 @@
+if (MLIR_CUDA_CONVERSIONS_ENABLED)
+ set(NVPTX_LIBS
+ MC
+ NVPTXCodeGen
+ NVPTXDesc
+ NVPTXInfo
+ )
+endif()
+
+if (MLIR_ROCM_CONVERSIONS_ENABLED)
+ set(AMDGPU_LIBS
+ MC
+ AMDGPUCodeGen
+ AMDGPUDesc
+ AMDGPUInfo
+ )
+endif()
+
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestTransforms
TestAllReduceLowering.cpp
@@ -5,6 +23,7 @@ add_mlir_library(MLIRTestTransforms
TestCallGraph.cpp
TestConstantFold.cpp
TestConvertGPUKernelToCubin.cpp
+ TestConvertGPUKernelToHsaco.cpp
TestDominance.cpp
TestLoopFusion.cpp
TestGpuMemoryPromotion.cpp
@@ -31,18 +50,26 @@ add_mlir_library(MLIRTestTransforms
MLIRStandardOpsIncGen
MLIRTestVectorTransformPatternsIncGen
+ LINK_COMPONENTS
+ ${AMDGPU_LIBS}
+ ${NVPTX_LIBS}
+
LINK_LIBS PUBLIC
MLIRAffineOps
MLIRAnalysis
MLIREDSC
MLIRGPU
- MLIRGPUtoCUDATransforms
+ MLIRGPUtoGPURuntimeTransforms
MLIRLinalgOps
MLIRLinalgTransforms
+ MLIRNVVMIR
MLIRSCF
MLIRGPU
MLIRPass
+ MLIRROCDLIR
MLIRStandardOpsTransforms
+ MLIRTargetNVVMIR
+ MLIRTargetROCDLIR
MLIRTestDialect
MLIRTransformUtils
MLIRVectorToSCF
diff --git a/mlir/test/lib/Transforms/TestConvertGPUKernelToCubin.cpp b/mlir/test/lib/Transforms/TestConvertGPUKernelToCubin.cpp
index e0c4c1907c4f..a347b2c28031 100644
--- a/mlir/test/lib/Transforms/TestConvertGPUKernelToCubin.cpp
+++ b/mlir/test/lib/Transforms/TestConvertGPUKernelToCubin.cpp
@@ -6,26 +6,36 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/NVVMIR.h"
+#include "llvm/Support/TargetSelect.h"
using namespace mlir;
#if MLIR_CUDA_CONVERSIONS_ENABLED
-static OwnedCubin compilePtxToCubinForTesting(const std::string &, Location,
- StringRef) {
+static OwnedBlob compilePtxToCubinForTesting(const std::string &, Location,
+ StringRef) {
const char data[] = "CUBIN";
return std::make_unique<std::vector<char>>(data, data + sizeof(data) - 1);
}
namespace mlir {
void registerTestConvertGPUKernelToCubinPass() {
- PassPipelineRegistration<>("test-kernel-to-cubin",
- "Convert all kernel functions to CUDA cubin blobs",
- [](OpPassManager &pm) {
- pm.addPass(createConvertGPUKernelToCubinPass(
- compilePtxToCubinForTesting));
- });
+ PassPipelineRegistration<>(
+ "test-kernel-to-cubin",
+ "Convert all kernel functions to CUDA cubin blobs",
+ [](OpPassManager &pm) {
+ // Initialize LLVM NVPTX backend.
+ LLVMInitializeNVPTXTarget();
+ LLVMInitializeNVPTXTargetInfo();
+ LLVMInitializeNVPTXTargetMC();
+ LLVMInitializeNVPTXAsmPrinter();
+
+ pm.addPass(createConvertGPUKernelToBlobPass(
+ translateModuleToNVVMIR, compilePtxToCubinForTesting,
+ "nvptx64-nvidia-cuda", "sm_35", "+ptx60", "nvvm.cubin"));
+ });
}
} // namespace mlir
#endif
diff --git a/mlir/test/lib/Transforms/TestConvertGPUKernelToHsaco.cpp b/mlir/test/lib/Transforms/TestConvertGPUKernelToHsaco.cpp
new file mode 100644
index 000000000000..54293a8099b4
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestConvertGPUKernelToHsaco.cpp
@@ -0,0 +1,41 @@
+//===- TestConvertGPUKernelToHsaco.cpp - Test gpu kernel hsaco lowering ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/ROCDLIR.h"
+#include "llvm/Support/TargetSelect.h"
+using namespace mlir;
+
+#if MLIR_ROCM_CONVERSIONS_ENABLED
+static OwnedBlob compileIsaToHsacoForTesting(const std::string &, Location,
+ StringRef) {
+ const char data[] = "HSACO";
+ return std::make_unique<std::vector<char>>(data, data + sizeof(data) - 1);
+}
+
+namespace mlir {
+void registerTestConvertGPUKernelToHsacoPass() {
+ PassPipelineRegistration<>(
+ "test-kernel-to-hsaco",
+ "Convert all kernel functions to ROCm hsaco blobs",
+ [](OpPassManager &pm) {
+ // Initialize LLVM AMDGPU backend.
+ LLVMInitializeAMDGPUTarget();
+ LLVMInitializeAMDGPUTargetInfo();
+ LLVMInitializeAMDGPUTargetMC();
+ LLVMInitializeAMDGPUAsmPrinter();
+
+ pm.addPass(createConvertGPUKernelToBlobPass(
+ translateModuleToROCDLIR, compileIsaToHsacoForTesting,
+ "amdgcn-amd-amdhsa", "gfx900", "-code-object-v3", "rocdl.hsaco"));
+ });
+}
+} // namespace mlir
+#endif
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index dc6286a827bb..e07acf4d21a8 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -38,6 +38,7 @@ config.build_examples = @LLVM_BUILD_EXAMPLES@
config.run_cuda_tests = @MLIR_CUDA_CONVERSIONS_ENABLED@
config.cuda_wrapper_library_dir = "@MLIR_CUDA_WRAPPER_LIBRARY_DIR@"
config.enable_cuda_runner = @MLIR_CUDA_RUNNER_ENABLED@
+config.run_rocm_tests = @MLIR_ROCM_CONVERSIONS_ENABLED@
config.vulkan_wrapper_library_dir = "@MLIR_VULKAN_WRAPPER_LIBRARY_DIR@"
config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@
diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
index 6a404221744b..cdd8ec3fe5f3 100644
--- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
+++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
@@ -15,7 +15,6 @@
#include "llvm/ADT/STLExtras.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
-#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
@@ -30,6 +29,7 @@
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Target/NVVMIR.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/InitLLVM.h"
@@ -57,8 +57,8 @@ inline void emit_cuda_error(const llvm::Twine &message, const char *buffer,
} \
}
-OwnedCubin compilePtxToCubin(const std::string ptx, Location loc,
- StringRef name) {
+OwnedBlob compilePtxToCubin(const std::string ptx, Location loc,
+ StringRef name) {
char jitErrorBuffer[4096] = {0};
RETURN_ON_CUDA_ERROR(cuInit(0), "cuInit");
@@ -97,7 +97,7 @@ OwnedCubin compilePtxToCubin(const std::string ptx, Location loc,
"cuLinkComplete");
char *cubinAsChar = static_cast<char *>(cubinData);
- OwnedCubin result =
+ OwnedBlob result =
std::make_unique<std::vector<char>>(cubinAsChar, cubinAsChar + cubinSize);
// This will also destroy the cubin data.
@@ -114,7 +114,9 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
auto &kernelPm = pm.nest<gpu::GPUModuleOp>();
kernelPm.addPass(createStripDebugInfoPass());
kernelPm.addPass(createLowerGpuOpsToNVVMOpsPass());
- kernelPm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
+ kernelPm.addPass(createConvertGPUKernelToBlobPass(
+ translateModuleToNVVMIR, compilePtxToCubin, "nvptx64-nvidia-cuda",
+ "sm_35", "+ptx60", "nvvm.cubin"));
pm.addPass(createLowerToLLVMPass());
pm.addPass(createConvertGpuLaunchFuncToGpuRuntimeCallsPass());
@@ -127,6 +129,13 @@ int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
+
+ // Initialize LLVM NVPTX backend.
+ LLVMInitializeNVPTXTarget();
+ LLVMInitializeNVPTXTargetInfo();
+ LLVMInitializeNVPTXTargetMC();
+ LLVMInitializeNVPTXAsmPrinter();
+
mlir::initializeLLVMPasses();
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 69b1d8d57bc5..159a7fd4bca5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -46,6 +46,7 @@ void registerTestLoopPermutationPass();
void registerTestCallGraphPass();
void registerTestConstantFold();
void registerTestConvertGPUKernelToCubinPass();
+void registerTestConvertGPUKernelToHsacoPass();
void registerTestDominancePass();
void registerTestFunc();
void registerTestGpuMemoryPromotionPass();
@@ -112,6 +113,9 @@ void registerTestPasses() {
registerTestConstantFold();
#if MLIR_CUDA_CONVERSIONS_ENABLED
registerTestConvertGPUKernelToCubinPass();
+#endif
+#if MLIR_ROCM_CONVERSIONS_ENABLED
+ registerTestConvertGPUKernelToHsacoPass();
#endif
registerTestBufferPlacementPreparationPass();
registerTestDominancePass();
More information about the Mlir-commits
mailing list