[Mlir-commits] [mlir] b43068e - [mlir][gpu] Update GPU translation to accept binaries.

Fabian Mora llvmlistbot at llvm.org
Fri Aug 11 17:29:49 PDT 2023


Author: Fabian Mora
Date: 2023-08-12T00:29:42Z
New Revision: b43068e8707dea0ad601377b3133b4abe89d370a

URL: https://github.com/llvm/llvm-project/commit/b43068e8707dea0ad601377b3133b4abe89d370a
DIFF: https://github.com/llvm/llvm-project/commit/b43068e8707dea0ad601377b3133b4abe89d370a.diff

LOG: [mlir][gpu] Update GPU translation to accept binaries.

== Commit message ==
Modifies GPU translation to accept GPU binaries, embedding them with the
object manager interface method `embedBinary`, and to accept kernel launch
operations, translating them with the interface method `launchKernel`.
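
The delegation described above can be modeled outside of MLIR. The following is a minimal sketch, not the code in this patch: every `Toy*` type and `translateLaunch` are hypothetical stand-ins, and the emitted strings are placeholders for the LLVM IR that a real handler would build.

```cpp
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-ins for gpu.binary and gpu.launch_func; not MLIR types.
struct ToyBinary;
struct ToyLaunch {
  std::string binaryName; // symbol of the binary holding the kernel
  std::string kernelName;
};

// Plays the role of the object manager interface: it decides how a binary is
// embedded and how a kernel launch is emitted.
struct ToyOffloadingHandler {
  virtual ~ToyOffloadingHandler() = default;
  virtual bool embedBinary(const ToyBinary &binary,
                           std::vector<std::string> &out) const = 0;
  virtual bool launchKernel(const ToyLaunch &launch, const ToyBinary &binary,
                            std::vector<std::string> &out) const = 0;
};

struct ToyBinary {
  std::string name;
  std::string blob;
  std::shared_ptr<ToyOffloadingHandler> handler;
};

// One possible handler: embed the blob as a global and launch from it.
struct ToyEmbedHandler : ToyOffloadingHandler {
  bool embedBinary(const ToyBinary &binary,
                   std::vector<std::string> &out) const override {
    out.push_back("global @" + binary.name + "_bin_cst = \"" + binary.blob +
                  "\"");
    return true;
  }
  bool launchKernel(const ToyLaunch &launch, const ToyBinary &binary,
                    std::vector<std::string> &out) const override {
    out.push_back("launch @" + launch.kernelName + " from @" + binary.name +
                  "_bin_cst");
    return true;
  }
};

// Same shape as the translation: find the binary named by the launch, then
// let that binary's handler decide what to emit.
bool translateLaunch(const ToyLaunch &launch,
                     const std::vector<ToyBinary> &binaries,
                     std::vector<std::string> &out) {
  for (const ToyBinary &binary : binaries)
    if (binary.name == launch.binaryName && binary.handler)
      return binary.handler->launchKernel(launch, binary, out);
  return false; // no binary holds the kernel
}
```

The point of the indirection is that the translation itself stays target-agnostic: a different handler can change how binaries are embedded and launched without touching the dispatch code.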

Depends on D154152

= Explanation =
**Summary:**
These patches aim to replace the current GPU compilation infrastructure, with extensibility and minimal future disruption as the primary goals.
The biggest updates performed by these patches are:
 - The introduction of Target attributes. These attributes handle compilation of GPU modules into binary strings. They can be implemented by any dialect, leaving the option for downstream users to implement their own serializations.
 - The introduction of the GPU binary operation. This operation stores GPU objects for different targets and can be invoked by `gpu.launch_func`.
 - Making `gpu.binary` & `gpu.launch_func` translatable to LLVM IR, with the translation being controlled by Object Manager attributes.
 - The introduction of the `gpu-module-to-binary` pass. This pass serializes GPU modules into GPU binaries, using the GPU targets available in the module.
 - The introduction of the `#gpu.select_object` object manager as the default object manager. It selects a single object for embedding in the IR; by default, it selects the first object.

These patches leave the current infrastructure in place, allowing for a migration period for downstream users.
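
The serialization step performed by `gpu-module-to-binary` can be pictured as a loop over the targets attached to a module. The sketch below is a simplified model under stated assumptions, not the pass itself: `ToyTarget`, `SerializedObject`, and `moduleToBinary` are hypothetical names, and failing the whole conversion when one target fails is an assumption of this sketch.

```cpp
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of a target attribute: a name plus a serializer that
// turns a module body into a binary string, or fails.
struct ToyTarget {
  std::string name;
  std::function<std::optional<std::string>(const std::string &)> serialize;
};

// Hypothetical model of one entry of a gpu.binary: target plus binary data.
struct SerializedObject {
  std::string target;
  std::string data;
};

// Serialize the module once per target, keeping the objects in target order.
// Assumption of this sketch: a single failing target fails the conversion.
std::optional<std::vector<SerializedObject>>
moduleToBinary(const std::string &moduleBody,
               const std::vector<ToyTarget> &targets) {
  std::vector<SerializedObject> objects;
  for (const ToyTarget &target : targets) {
    std::optional<std::string> data = target.serialize(moduleBody);
    if (!data)
      return std::nullopt;
    objects.push_back({target.name, *data});
  }
  return objects;
}
```

Keeping the objects in the same order as the targets matters for the default object manager, which picks an object by position when no selector is given.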

**Examples:**
- GPU modules using target attributes:
```
gpu.module @my_module [#gpu.nvptx<chip = "sm_90">, #gpu.amdgpu, #gpu.amdgpu<chip = "gfx90a">] {
...
}
```
- Applying the `gpu-module-to-binary` pass:
```
gpu.module @my_module [#gpu.nvptx<chip = "sm_90">, #gpu.amdgpu] {
...
}
; mlir-opt --gpu-module-to-binary
gpu.binary @my_module [#gpu.object<#gpu.nvptx<chip = "sm_90">, "BINARY DATA">, #gpu.object<#gpu.amdgpu, "BINARY DATA">]
```
- Choosing the `#gpu.amdgpu` object for embedding:
```
gpu.binary @my_module <#gpu.select_object<#gpu.amdgpu>> [#gpu.object<#gpu.nvptx<chip = "sm_90">, "BINARY DATA">, #gpu.object<#gpu.amdgpu, "BINARY DATA">]
; It's also valid to pass the index of the object.
gpu.binary @my_module <#gpu.select_object<1>> [#gpu.object<#gpu.nvptx<chip = "sm_90">, "BINARY DATA">, #gpu.object<#gpu.amdgpu, "BINARY DATA">]
```
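
The selection rules shown in these examples (first object by default, or an explicit index, or a matching target) can be sketched in plain C++. This is an illustrative model only: `GpuObject`, `Selector`, and `selectObject` are hypothetical names, targets are compared as plain strings, and returning an empty result for an out-of-range index or an unmatched target is an assumption of this sketch.

```cpp
#include <cstddef>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical model of a #gpu.object: a target name and its binary data.
struct GpuObject {
  std::string target;
  std::string data;
};

// No selector (monostate), an index, or a target name.
using Selector = std::variant<std::monostate, std::size_t, std::string>;

// Pick the single object to embed, or nothing if the selector matches none.
std::optional<GpuObject> selectObject(const std::vector<GpuObject> &objects,
                                      const Selector &selector) {
  if (const auto *index = std::get_if<std::size_t>(&selector)) {
    if (*index < objects.size())
      return objects[*index];
    return std::nullopt;
  }
  if (const auto *target = std::get_if<std::string>(&selector)) {
    for (const GpuObject &object : objects)
      if (object.target == *target)
        return object;
    return std::nullopt;
  }
  // Default: the first object, if there is one.
  if (objects.empty())
    return std::nullopt;
  return objects.front();
}
```

With the two objects from the example above, a default selector yields the NVPTX object, while either the index `1` or the AMDGPU target yields the AMDGPU object.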

**Testing:**
This infrastructure was tested on 2 systems, one with an NVIDIA V100 and the other with an AMD MI250X; in both cases the tests completed successfully.

Input files:
 - **test.cpp**  {F28084155}
 - **test_nvvm.mlir** {F28084157}
 - **test_rocdl.mlir** {F28084162}

1.  Steps for assembling the test for the NVIDIA system:
```
mlir-opt --gpu-to-llvm --gpu-module-to-binary test_nvvm.mlir | mlir-translate --mlir-to-llvmir -o test_nvptx.ll
clang++ test_nvptx.ll test.cpp -l
```
**Output file:** test_nvptx.ll {F28084210}

2.  Steps for assembling the test for the AMD system:
```
mlir-opt --gpu-to-llvm --gpu-module-to-binary test_rocdl.mlir | mlir-translate --mlir-to-llvmir -o test_amdgpu.ll
clang++ test_amdgpu.ll test.cpp -l
```
**Output file:** test_amdgpu.ll {F28084217}

== Diff list ==
The following patches implement the proposal described in: https://discourse.llvm.org/t/rfc-extending-mlir-gpu-device-codegen-pipeline/70199/54 :
 - D154098: Add a `GlobalSymbol` trait.
 - D154097: Add a parameter for passing default values to `StringRefParameter`
 - D154100: Adds a utility class for serializing operations to binary strings.
 - D154104: Add GPU target attribute interface.
 - D154113: Add target attribute to GPU modules.
 - D154117: Adds the NVPTX target attribute.
 - D154129: Adds the AMDGPU target attribute.
 - D154108: Add the GPU object manager attribute interface.
 - D154132: Add `gpu.binary` op and `#gpu.object` attribute.
 - D154137: Modifies `gpu.launch_func` to allow lowering it after gpu-to-llvm.
 - D154147: Add the Select Object compilation attribute.
 - D154149: Add the `gpu-module-to-binary` pass.
 - D154152: Add GPU target support to `gpu-to-llvm`.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D154153

Added: 
    mlir/test/Target/LLVMIR/gpu.mlir

Modified: 
    mlir/lib/Target/LLVMIR/CMakeLists.txt
    mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
    mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index 868ccbbb10620d..dd97ccf8868863 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -57,6 +57,8 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
   MLIROpenACCToLLVMIRTranslation
   MLIROpenMPToLLVMIRTranslation
   MLIRROCDLToLLVMIRTranslation
+  MLIRNVVMTarget
+  MLIRROCDLTarget
   )
 
 add_mlir_translation_library(MLIRTargetLLVMIRImport

diff  --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
index 45588937795348..b7c1c40e13126a 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -13,6 +13,8 @@
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Target/LLVM/NVVM/Target.h"
+#include "mlir/Target/LLVM/ROCDL/Target.h"
 #include "mlir/Target/LLVMIR/Dialect/All.h"
 #include "mlir/Target/LLVMIR/Export.h"
 #include "mlir/Tools/mlir-translate/Translation.h"
@@ -36,6 +38,8 @@ void registerToLLVMIRTranslation() {
       },
       [](DialectRegistry &registry) {
         registry.insert<DLTIDialect, func::FuncDialect>();
+        registerNVVMTarget(registry);
+        registerROCDLTarget(registry);
         registerAllToLLVMIRTranslations(registry);
       });
 }

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
index 3e677bcc5d7f6b..ef98f737f07315 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
@@ -12,10 +12,28 @@
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
 namespace {
+LogicalResult launchKernel(gpu::LaunchFuncOp launchOp,
+                           llvm::IRBuilderBase &builder,
+                           LLVM::ModuleTranslation &moduleTranslation) {
+  auto kernelBinary = SymbolTable::lookupNearestSymbolFrom<gpu::BinaryOp>(
+      launchOp, launchOp.getKernelModuleName());
+  if (!kernelBinary) {
+    launchOp.emitError("Couldn't find the binary holding the kernel: ")
+        << launchOp.getKernelModuleName();
+    return failure();
+  }
+  auto offloadingHandler =
+      dyn_cast<gpu::OffloadingLLVMTranslationAttrInterface>(
+          kernelBinary.getOffloadingHandlerAttr());
+  assert(offloadingHandler && "Invalid offloading handler.");
+  return offloadingHandler.launchKernel(launchOp, kernelBinary, builder,
+                                        moduleTranslation);
+}
 
 class GPUDialectLLVMIRTranslationInterface
     : public LLVMTranslationDialectInterface {
@@ -23,9 +41,23 @@ class GPUDialectLLVMIRTranslationInterface
   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
 
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const override {
-    return isa<gpu::GPUModuleOp>(op) ? success() : failure();
+    return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
+        .Case([&](gpu::GPUModuleOp) { return success(); })
+        .Case([&](gpu::BinaryOp op) {
+          auto offloadingHandler =
+              dyn_cast<gpu::OffloadingLLVMTranslationAttrInterface>(
+                  op.getOffloadingHandlerAttr());
+          assert(offloadingHandler && "Invalid offloading handler.");
+          return offloadingHandler.embedBinary(op, builder, moduleTranslation);
+        })
+        .Case([&](gpu::LaunchFuncOp op) {
+          return launchKernel(op, builder, moduleTranslation);
+        })
+        .Default([&](Operation *op) {
+          return op->emitError("unsupported GPU operation: ") << op->getName();
+        });
   }
 };
 

diff  --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir
new file mode 100644
index 00000000000000..fddbbee962c1ae
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/gpu.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// Checking the translation of the `gpu.binary` & `gpu.launch_fun` ops.
+module attributes {gpu.container_module} {
+  // CHECK: [[ARGS_TY:%.*]] = type { i32, i32 }
+  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
+  // CHECK: @kernel_module_kernel_kernel_name = private unnamed_addr constant [7 x i8] c"kernel\00", align 1
+  gpu.binary @kernel_module  [#gpu.object<#nvvm.target, "BLOB">]
+  llvm.func @foo() {
+    // CHECK: [[ARGS:%.*]] = alloca %{{.*}}, align 8
+    // CHECK: [[ARGS_ARRAY:%.*]] = alloca ptr, i64 2, align 8
+    // CHECK: [[ARG0:%.*]] = getelementptr inbounds [[ARGS_TY]], ptr [[ARGS]], i32 0, i32 0
+    // CHECK: store i32 32, ptr [[ARG0]], align 4
+    // CHECK: %{{.*}} = getelementptr ptr, ptr [[ARGS_ARRAY]], i32 0
+    // CHECK: store ptr [[ARG0]], ptr %{{.*}}, align 8
+    // CHECK: [[ARG1:%.*]] = getelementptr inbounds [[ARGS_TY]], ptr [[ARGS]], i32 0, i32 1
+    // CHECK: store i32 32, ptr [[ARG1]], align 4
+    // CHECK: %{{.*}} = getelementptr ptr, ptr [[ARGS_ARRAY]], i32 1
+    // CHECK: store ptr [[ARG1]], ptr %{{.*}}, align 8
+    // CHECK: [[MODULE:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst)
+    // CHECK: [[FUNC:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[MODULE]], ptr @kernel_module_kernel_kernel_name)
+    // CHECK: [[STREAM:%.*]] = call ptr @mgpuStreamCreate()
+    // CHECK: call void @mgpuLaunchKernel(ptr [[FUNC]], i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i32 256, ptr [[STREAM]], ptr [[ARGS_ARRAY]], ptr null)
+    // CHECK: call void @mgpuStreamSynchronize(ptr [[STREAM]])
+    // CHECK: call void @mgpuStreamDestroy(ptr [[STREAM]])
+    // CHECK: call void @mgpuModuleUnload(ptr [[MODULE]])
+    %0 = llvm.mlir.constant(8 : index) : i64
+    %1 = llvm.mlir.constant(32 : i32) : i32
+    %2 = llvm.mlir.constant(256 : i32) : i32
+    gpu.launch_func @kernel_module::@kernel blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %2 args(%1 : i32, %1 : i32)
+    llvm.return
+  }
+}
+
+// -----
+
+// Checking the correct selection of the second object using an index as a selector.
+module {
+  // CHECK: @kernel_module_bin_cst = internal constant [1 x i8] c"1", align 8
+  gpu.binary @kernel_module <#gpu.select_object<1>> [#gpu.object<#nvvm.target, "0">, #gpu.object<#nvvm.target, "1">]
+}
+
+// -----
+
+// Checking the correct selection of the second object using a target as a selector.
+module {
+  // CHECK: @kernel_module_bin_cst = internal constant [6 x i8] c"AMDGPU", align 8
+  gpu.binary @kernel_module <#gpu.select_object<#rocdl.target>> [#gpu.object<#nvvm.target, "NVPTX">, #gpu.object<#rocdl.target, "AMDGPU">]
+}
+
+// -----
+
+// Checking the translation of `gpu.launch_fun` with an async dependency.
+module attributes {gpu.container_module} {
+  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
+  gpu.binary @kernel_module  [#gpu.object<#rocdl.target, "BLOB">]
+  llvm.func @foo() {
+    %0 = llvm.mlir.constant(8 : index) : i64
+    // CHECK: = call ptr @mgpuStreamCreate()
+    // CHECK-NEXT: = alloca {{.*}}, align 8
+    // CHECK-NEXT: [[ARGS:%.*]] = alloca ptr, i64 0, align 8
+    // CHECK-NEXT: [[MODULE:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst)
+    // CHECK-NEXT: [[FUNC:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[MODULE]], ptr @kernel_module_kernel_kernel_name)
+    // CHECK-NEXT: call void @mgpuLaunchKernel(ptr [[FUNC]], i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i32 0, ptr {{.*}}, ptr [[ARGS]], ptr null)
+    // CHECK-NEXT: call void @mgpuModuleUnload(ptr [[MODULE]])
+    // CHECK-NEXT: call void @mgpuStreamSynchronize(ptr %{{.*}})
+    // CHECK-NEXT: call void @mgpuStreamDestroy(ptr %{{.*}})
+    %1 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
+    gpu.launch_func <%1 : !llvm.ptr> @kernel_module::@kernel blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64
+    llvm.call @mgpuStreamSynchronize(%1) : (!llvm.ptr) -> ()
+    llvm.call @mgpuStreamDestroy(%1) : (!llvm.ptr) -> ()
+    llvm.return
+  }
+  llvm.func @mgpuStreamCreate() -> !llvm.ptr
+  llvm.func @mgpuStreamSynchronize(!llvm.ptr)
+  llvm.func @mgpuStreamDestroy(!llvm.ptr)
+}


        

