[Mlir-commits] [mlir] 43752a2 - [mlir][gpu] Add the `gpu-module-to-binary` pass.
Fabian Mora
llvmlistbot at llvm.org
Fri Aug 11 17:25:01 PDT 2023
Author: Fabian Mora
Date: 2023-08-12T00:24:53Z
New Revision: 43752a2aa31a28352be6f38be0ac8108394f6d1d
URL: https://github.com/llvm/llvm-project/commit/43752a2aa31a28352be6f38be0ac8108394f6d1d
DIFF: https://github.com/llvm/llvm-project/commit/43752a2aa31a28352be6f38be0ac8108394f6d1d.diff
LOG: [mlir][gpu] Add the `gpu-module-to-binary` pass.
**For an explanation of these patches see D154153.**
Commit message:
This pass converts GPU modules into GPU binaries, serializing all targets present
in a GPU module by invoking the `serializeToObject` target attribute method.
Depends on D154147
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D154149
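As a point of reference (not part of the patch), here is a minimal C++ sketch of how the new entry point introduced below might be driven on a module whose nested `gpu.module` ops already carry target attributes; the wrapper function name is illustrative, and dialect/target registration is assumed to have happened elsewhere:

#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

/// Replace every nested `gpu.module` in `module` with a `gpu.binary` holding
/// one `#gpu.object` per target attribute, using the default target options
/// (binary format, no offloading handler).
static LogicalResult serializeGpuModules(ModuleOp module) {
  return gpu::transformGpuModulesToBinaries(module.getOperation());
}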
Added:
mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
mlir/test/Dialect/GPU/module-to-binary-nvvm.mlir
mlir/test/Dialect/GPU/module-to-binary-rocdl.mlir
Modified:
mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
mlir/lib/Dialect/GPU/CMakeLists.txt
mlir/test/lit.cfg.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 970dfea4677d83..f0675dade3a8b6 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -70,6 +70,13 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
}
namespace gpu {
+/// Searches for all GPU modules in `op` and transforms them into GPU binary
+/// operations. The resulting `gpu.binary` has `handler` as its offloading
+/// handler attribute.
+LogicalResult transformGpuModulesToBinaries(
+ Operation *op, OffloadingLLVMTranslationAttrInterface handler = nullptr,
+ const gpu::TargetOptions &options = {});
+
/// Base pass class to serialize kernel functions through LLVM into
/// user-specified IR and add the resulting blob as module attribute.
class SerializeToBlobPass : public OperationPass<gpu::GPUModuleOp> {
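The `handler` parameter of the new entry point accepts any attribute implementing `OffloadingLLVMTranslationAttrInterface` and is attached to every resulting `gpu.binary`. Building on the sketch above, a hedged example (the wrapper function and null-check are illustrative, not part of the patch) of forwarding a caller-supplied attribute:

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

/// Attach `handlerAttr` as the offloading handler of every `gpu.binary`
/// produced from `module`, provided it implements the offloading interface.
static LogicalResult serializeWithHandler(ModuleOp module, Attribute handlerAttr) {
  gpu::OffloadingLLVMTranslationAttrInterface handler = nullptr;
  if (handlerAttr)
    handler = dyn_cast<gpu::OffloadingLLVMTranslationAttrInterface>(handlerAttr);
  return gpu::transformGpuModulesToBinaries(module.getOperation(), handler);
}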
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 7602f8bcc6a482..b3989009292bff 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -55,4 +55,31 @@ def GpuDecomposeMemrefsPass : Pass<"gpu-decompose-memrefs"> {
];
}
+def GpuModuleToBinaryPass
+ : Pass<"gpu-module-to-binary", ""> {
+ let summary = "Transforms a GPU module into a GPU binary.";
+ let description = [{
+ This pass searches for all nested GPU modules and serializes each module
+ using the target attributes attached to it, producing a GPU binary with an
+ object for every target.
+
+ The `format` argument can have the following values:
+ 1. `offloading`, `llvm`: producing an offloading representation.
+ 2. `assembly`, `isa`: producing assembly code.
+ 3. `binary`, `bin`: producing binaries.
+ }];
+ let options = [
+ Option<"offloadingHandler", "handler", "Attribute", "nullptr",
+ "Offloading handler to be attached to the resulting binary op.">,
+ Option<"toolkitPath", "toolkit", "std::string", [{""}],
+ "Toolkit path.">,
+ ListOption<"linkFiles", "l", "std::string",
+ "Extra files to link to.">,
+ Option<"cmdOptions", "opts", "std::string", [{""}],
+ "Command line options to pass to the tools.">,
+ Option<"compilationTarget", "format", "std::string", [{"bin"}],
+ "The target representation of the compilation process.">
+ ];
+}
+
#endif // MLIR_DIALECT_GPU_PASSES
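For illustration, these pass options map onto a `gpu::TargetOptions` value when the C++ entry point is called directly. A rough sketch, assuming the `TargetOptions` constructor takes (toolkit path, link files, command-line options, compilation target) in the order used by the pass implementation below; the toolkit path is a placeholder:

#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

/// Roughly the C++ equivalent of running
/// `gpu-module-to-binary{toolkit=/opt/cuda format=isa}` over `module`.
static LogicalResult serializeToAssembly(ModuleOp module) {
  gpu::TargetOptions options(/*toolkitPath=*/"/opt/cuda",
                             /*linkFiles=*/{},
                             /*cmdOptions=*/"",
                             gpu::TargetOptions::assembly);
  return gpu::transformGpuModulesToBinaries(module.getOperation(),
                                            /*handler=*/nullptr, options);
}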
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 81d7bf96bbf4c9..00b66f9699a3d9 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -51,6 +51,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/GlobalIdRewriter.cpp
Transforms/KernelOutlining.cpp
Transforms/MemoryPromotion.cpp
+ Transforms/ModuleToBinary.cpp
Transforms/ParallelLoopMapper.cpp
Transforms/SerializeToBlob.cpp
Transforms/SerializeToCubin.cpp
@@ -85,10 +86,12 @@ add_mlir_dialect_library(MLIRGPUTransforms
MLIRGPUToLLVMIRTranslation
MLIRLLVMToLLVMIRTranslation
MLIRMemRefDialect
+ MLIRNVVMTarget
MLIRPass
MLIRSCFDialect
MLIRSideEffectInterfaces
MLIRSupport
+ MLIRROCDLTarget
MLIRTransformUtils
)
diff --git a/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
new file mode 100644
index 00000000000000..4888c8f79cc1f6
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
@@ -0,0 +1,122 @@
+//===- ModuleToBinary.cpp - Transforms GPU modules to GPU binaries ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the `GpuModuleToBinaryPass` pass, transforming GPU
+// modules into GPU binaries.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Target/LLVM/NVVM/Target.h"
+#include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/StringSwitch.h"
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUMODULETOBINARYPASS
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+namespace {
+class GpuModuleToBinaryPass
+ : public impl::GpuModuleToBinaryPassBase<GpuModuleToBinaryPass> {
+public:
+ using Base::Base;
+ void getDependentDialects(DialectRegistry &registry) const override;
+ void runOnOperation() final;
+};
+} // namespace
+
+void GpuModuleToBinaryPass::getDependentDialects(
+ DialectRegistry &registry) const {
+ // Register all GPU related translations.
+ registerLLVMDialectTranslation(registry);
+ registerGPUDialectTranslation(registry);
+#if MLIR_CUDA_CONVERSIONS_ENABLED == 1
+ registerNVVMTarget(registry);
+#endif
+#if MLIR_ROCM_CONVERSIONS_ENABLED == 1
+ registerROCDLTarget(registry);
+#endif
+}
+
+void GpuModuleToBinaryPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ int targetFormat = llvm::StringSwitch<int>(compilationTarget)
+ .Cases("offloading", "llvm", TargetOptions::offload)
+ .Cases("assembly", "isa", TargetOptions::assembly)
+ .Cases("binary", "bin", TargetOptions::binary)
+ .Default(-1);
+ if (targetFormat == -1) {
+   getOperation()->emitError() << "Invalid format specified.";
+   return signalPassFailure();
+ }
+ TargetOptions targetOptions(
+ toolkitPath, linkFiles, cmdOptions,
+ static_cast<TargetOptions::CompilationTarget>(targetFormat));
+ if (failed(transformGpuModulesToBinaries(
+ getOperation(),
+ offloadingHandler ? dyn_cast<OffloadingLLVMTranslationAttrInterface>(
+ offloadingHandler.getValue())
+ : OffloadingLLVMTranslationAttrInterface(nullptr),
+ targetOptions)))
+ return signalPassFailure();
+}
+
+namespace {
+LogicalResult moduleSerializer(GPUModuleOp op,
+ OffloadingLLVMTranslationAttrInterface handler,
+ const TargetOptions &targetOptions) {
+ OpBuilder builder(op->getContext());
+ SmallVector<Attribute> objects;
+ // Serialize all targets.
+ for (auto targetAttr : op.getTargetsAttr()) {
+ assert(targetAttr && "Target attribute cannot be null.");
+ auto target = dyn_cast<gpu::TargetAttrInterface>(targetAttr);
+ assert(target &&
+   "Target attribute doesn't implement `TargetAttrInterface`.");
+ std::optional<SmallVector<char, 0>> object =
+ target.serializeToObject(op, targetOptions);
+
+ if (!object) {
+ op.emitError("An error happened while serializing the module.");
+ return failure();
+ }
+
+ objects.push_back(builder.getAttr<gpu::ObjectAttr>(
+ target,
+ builder.getStringAttr(StringRef(object->data(), object->size()))));
+ }
+ builder.setInsertionPointAfter(op);
+ builder.create<gpu::BinaryOp>(op.getLoc(), op.getName(), handler,
+ builder.getArrayAttr(objects));
+ op->erase();
+ return success();
+}
+} // namespace
+
+LogicalResult mlir::gpu::transformGpuModulesToBinaries(
+ Operation *op, OffloadingLLVMTranslationAttrInterface handler,
+ const gpu::TargetOptions &targetOptions) {
+ for (Region &region : op->getRegions())
+ for (Block &block : region.getBlocks())
+ for (auto module :
+ llvm::make_early_inc_range(block.getOps<GPUModuleOp>()))
+ if (failed(moduleSerializer(module, handler, targetOptions)))
+ return failure();
+ return success();
+}
diff --git a/mlir/test/Dialect/GPU/module-to-binary-nvvm.mlir b/mlir/test/Dialect/GPU/module-to-binary-nvvm.mlir
new file mode 100644
index 00000000000000..555b28a8293ee4
--- /dev/null
+++ b/mlir/test/Dialect/GPU/module-to-binary-nvvm.mlir
@@ -0,0 +1,25 @@
+// REQUIRES: host-supports-nvptx
+// RUN: mlir-opt %s --gpu-module-to-binary="format=llvm" | FileCheck %s
+// RUN: mlir-opt %s --gpu-module-to-binary="format=isa" | FileCheck %s
+
+module attributes {gpu.container_module} {
+ // CHECK-LABEL:gpu.binary @kernel_module1
+ // CHECK:[#gpu.object<#nvvm.target<chip = "sm_70">, "{{.*}}">]
+ gpu.module @kernel_module1 [#nvvm.target<chip = "sm_70">] {
+ llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr<f32>,
+ %arg2: !llvm.ptr<f32>, %arg3: i64, %arg4: i64,
+ %arg5: i64) attributes {gpu.kernel} {
+ llvm.return
+ }
+ }
+
+ // CHECK-LABEL:gpu.binary @kernel_module2
+ // CHECK:[#gpu.object<#nvvm.target<flags = {fast}>, "{{.*}}">, #gpu.object<#nvvm.target, "{{.*}}">]
+ gpu.module @kernel_module2 [#nvvm.target<flags = {fast}>, #nvvm.target] {
+ llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr<f32>,
+ %arg2: !llvm.ptr<f32>, %arg3: i64, %arg4: i64,
+ %arg5: i64) attributes {gpu.kernel} {
+ llvm.return
+ }
+ }
+}
diff --git a/mlir/test/Dialect/GPU/module-to-binary-rocdl.mlir b/mlir/test/Dialect/GPU/module-to-binary-rocdl.mlir
new file mode 100644
index 00000000000000..fb7cfb70c17ed3
--- /dev/null
+++ b/mlir/test/Dialect/GPU/module-to-binary-rocdl.mlir
@@ -0,0 +1,25 @@
+// REQUIRES: host-supports-amdgpu
+// RUN: mlir-opt %s --gpu-module-to-binary="format=llvm" | FileCheck %s
+// RUN: mlir-opt %s --gpu-module-to-binary="format=isa" | FileCheck %s
+
+module attributes {gpu.container_module} {
+ // CHECK-LABEL:gpu.binary @kernel_module1
+ // CHECK:[#gpu.object<#rocdl.target<chip = "gfx90a">, "{{.*}}">]
+ gpu.module @kernel_module1 [#rocdl.target<chip = "gfx90a">] {
+ llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr<f32>,
+ %arg2: !llvm.ptr<f32>, %arg3: i64, %arg4: i64,
+ %arg5: i64) attributes {gpu.kernel} {
+ llvm.return
+ }
+ }
+
+ // CHECK-LABEL:gpu.binary @kernel_module2
+ // CHECK:[#gpu.object<#rocdl.target<flags = {fast}>, "{{.*}}">, #gpu.object<#rocdl.target, "{{.*}}">]
+ gpu.module @kernel_module2 [#rocdl.target<flags = {fast}>, #rocdl.target] {
+ llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr<f32>,
+ %arg2: !llvm.ptr<f32>, %arg3: i64, %arg4: i64,
+ %arg5: i64) attributes {gpu.kernel} {
+ llvm.return
+ }
+ }
+}
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 3a8bdbfcec280c..fb99422c3ff5a3 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -210,3 +210,9 @@ def have_host_jit_feature_support(feature_name):
if have_host_jit_feature_support("jit"):
config.available_features.add("host-supports-jit")
+
+if config.run_cuda_tests:
+ config.available_features.add("host-supports-nvptx")
+
+if config.run_rocm_tests:
+ config.available_features.add("host-supports-amdgpu")