[Mlir-commits] [mlir] 08b63db - [MLIR][GPU] Add GPU launch op support for dynamic shared memory

Uday Bondhugula llvmlistbot at llvm.org
Fri Oct 1 04:21:13 PDT 2021


Author: Uday Bondhugula
Date: 2021-10-01T16:46:07+05:30
New Revision: 08b63db8bb3ea847543351e1268be31ea327ad6b

URL: https://github.com/llvm/llvm-project/commit/08b63db8bb3ea847543351e1268be31ea327ad6b
DIFF: https://github.com/llvm/llvm-project/commit/08b63db8bb3ea847543351e1268be31ea327ad6b.diff

LOG: [MLIR][GPU] Add GPU launch op support for dynamic shared memory

Add support for dynamic shared memory for GPU launch ops: add an
optional operand to gpu.launch and gpu.launch_func ops to specify the
amount of "dynamic" shared memory to use. Update lowerings to connect
this operand to the GPU runtime.

Differential Revision: https://reviews.llvm.org/D110800

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
    mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
    mlir/test/Conversion/GPUToSPIRV/builtins.mlir
    mlir/test/Dialect/GPU/invalid.mlir
    mlir/test/Dialect/GPU/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index f3ce173108590..97359bce201ef 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -289,6 +289,7 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
                SymbolRefAttr:$kernel,
                Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
                Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
+               Optional<I32>:$dynamicSharedMemorySize,
                Variadic<AnyType>:$operands)>,
     Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
   let summary = "Launches a function as a GPU kernel";
@@ -317,7 +318,12 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
     dimensions as arguments. When a lower-dimensional kernel is required,
     unused sizes must be explicitly set to `1`.
 
-    The remaining operands are passed as arguments to the kernel function.
+    The remaining operands are optional. The first optional operand corresponds
+    to the amount of dynamic shared memory a kernel's workgroup should be
+    allocated; when this operand is not present, a zero size is assumed.
+
+    The remaining operands if present are passed as arguments to the kernel
+    function.
 
     Example:
 
@@ -360,6 +366,8 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
           @kernels::@kernel_1             // Kernel function.
           blocks in (%cst, %cst, %cst)    // Grid size.
           threads in (%cst, %cst, %cst)   // Block size.
+          dynamic_shared_memory_size %s   // (Optional) Amount of dynamic shared
+                                          // memory to allocate for a workgroup.
           args(%arg0 : f32,               // (Optional) Kernel arguments.
                %arg1 : memref<?xf32, 1>)
     }
@@ -370,7 +378,8 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
 
   let builders = [
     OpBuilder<(ins "GPUFuncOp":$kernelFunc, "KernelDim3":$gridSize,
-      "KernelDim3":$blockSize, "ValueRange":$kernelOperands)>
+      "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
+      "ValueRange":$kernelOperands)>
   ];
 
   let extraClassDeclaration = [{
@@ -411,24 +420,30 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
       $kernel
       `blocks` `in` ` ` `(`$gridSizeX`,` $gridSizeY`,` $gridSizeZ`)`
       `threads` `in` ` ` `(`$blockSizeX`,` $blockSizeY`,` $blockSizeZ`)`
-      custom<LaunchFuncOperands>($operands, type($operands))
-      attr-dict
+      (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
+      custom<LaunchFuncOperands>($operands, type($operands)) attr-dict
   }];
 }
 
 def GPU_LaunchOp : GPU_Op<"launch">,
     Arguments<(ins Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
-               Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ)>,
+               Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
+               Optional<I32>:$dynamicSharedMemorySize)>,
     Results<(outs)> {
   let summary = "GPU kernel launch operation";
 
   let description = [{
     Launch a kernel on the specified grid of thread blocks. The body of the
     kernel is defined by the single region that this operation contains. The
-    operation takes six operands, with first three operands being grid sizes
-    along x,y,z dimensions and the following three arguments being block sizes
-    along x,y,z dimension. When a lower-dimensional kernel is required,
-    unused sizes must be explicitly set to `1`.
+    operation takes six operands followed by an optional operand: the first
+    three operands are grid sizes along the x,y,z dimensions and the following
+    three are block sizes along the x,y,z dimensions. The last operand is
+    optional and corresponds to the amount of dynamic shared memory a kernel's
+    workgroup should be allocated; when this operand is not present, a zero size
+    is assumed.
+
+    When a lower-dimensional kernel is required, unused sizes must
+    be explicitly set to `1`.
 
     The body region has _twelve_ arguments, grouped as follows:
 
@@ -442,7 +457,8 @@ def GPU_LaunchOp : GPU_Op<"launch">,
     ```
     operation ::= `gpu.launch` `block` `(` ssa-id-list `)` `in` ssa-reassignment
                              `threads` `(` ssa-id-list `)` `in` ssa-reassignment
-                               region attr-dict?
+                             (dynamic_shared_memory_size ssa-use)?
+                             region attr-dict?
     ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
     ```
 
@@ -490,7 +506,8 @@ def GPU_LaunchOp : GPU_Op<"launch">,
   let builders = [
     OpBuilder<(ins "Value":$gridSizeX, "Value":$gridSizeY,
       "Value":$gridSizeZ, "Value":$blockSizeX, "Value":$blockSizeY,
-      "Value":$blockSizeZ)>
+      "Value":$blockSizeZ,
+      CArg<"Value", "nullptr">:$dynamic_shared_memory_size)>
   ];
 
   let extraClassDeclaration = [{
@@ -510,6 +527,9 @@ def GPU_LaunchOp : GPU_Op<"launch">,
 
     static StringRef getBlocksKeyword() { return "blocks"; }
     static StringRef getThreadsKeyword() { return "threads"; }
+    static StringRef getDynamicSharedMemorySizeKeyword() {
+      return "dynamic_shared_memory_size";
+    }
 
     /// The number of launch configuration operands, placed at the leading
     /// positions of the operand list.

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 40a4463f7dbe6..ff4ae51c4bde6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -745,13 +745,15 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
   // Create array of pointers to kernel arguments.
   auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
   auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
-  launchKernelCallBuilder.create(loc, rewriter,
-                                 {function.getResult(0), adaptor.gridSizeX(),
-                                  adaptor.gridSizeY(), adaptor.gridSizeZ(),
-                                  adaptor.blockSizeX(), adaptor.blockSizeY(),
-                                  adaptor.blockSizeZ(),
-                                  /*sharedMemBytes=*/zero, stream, kernelParams,
-                                  /*extra=*/nullpointer});
+  Value dynamicSharedMemorySize = launchOp.dynamicSharedMemorySize()
+                                      ? launchOp.dynamicSharedMemorySize()
+                                      : zero;
+  launchKernelCallBuilder.create(
+      loc, rewriter,
+      {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
+       adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
+       adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
+       /*extra=*/nullpointer});
 
   if (launchOp.asyncToken()) {
     // Async launch: make dependent ops use the same stream.

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7c1bf8e654b81..9a75bb55dda63 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -353,10 +353,13 @@ void gpu::addAsyncDependency(Operation *op, Value token) {
 
 void LaunchOp::build(OpBuilder &builder, OperationState &result,
                      Value gridSizeX, Value gridSizeY, Value gridSizeZ,
-                     Value blockSizeX, Value blockSizeY, Value blockSizeZ) {
+                     Value blockSizeX, Value blockSizeY, Value blockSizeZ,
+                     Value dynamicSharedMemorySize) {
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands(
       {gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
+  if (dynamicSharedMemorySize)
+    result.addOperands(dynamicSharedMemorySize);
 
   // Create a kernel body region with kNumConfigRegionAttributes + N arguments,
   // where the first kNumConfigRegionAttributes arguments have `index` type and
@@ -406,7 +409,8 @@ static LogicalResult verify(LaunchOp op) {
   // for block/thread identifiers and grid/block sizes.
   if (!op.body().empty()) {
     if (op.body().getNumArguments() !=
-        LaunchOp::kNumConfigOperands + op.getNumOperands())
+        LaunchOp::kNumConfigOperands + op.getNumOperands() -
+            (op.dynamicSharedMemorySize() ? 1 : 0))
       return op.emitOpError("unexpected number of region arguments");
   }
 
@@ -450,6 +454,9 @@ static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
   p << ' ' << op.getThreadsKeyword();
   printSizeAssignment(p, op.getBlockSize(), op.getBlockSizeOperandValues(),
                       op.getThreadIds());
+  if (op.dynamicSharedMemorySize())
+    p << ' ' << op.getDynamicSharedMemorySizeKeyword() << ' '
+      << op.dynamicSharedMemorySize();
 
   p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
   p.printOptionalAttrDict(op->getAttrs());
@@ -521,6 +528,15 @@ static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
                              result.operands))
     return failure();
 
+  OpAsmParser::OperandType dynamicSharedMemorySize;
+  if (!parser.parseOptionalKeyword(
+          LaunchOp::getDynamicSharedMemorySizeKeyword()))
+    if (parser.parseOperand(dynamicSharedMemorySize) ||
+        parser.resolveOperand(dynamicSharedMemorySize,
+                              parser.getBuilder().getI32Type(),
+                              result.operands))
+      return failure();
+
   // Introduce the body region and parse it. The region has
   // kNumConfigRegionAttributes arguments that correspond to
   // block/thread identifiers and grid/block sizes, all of the `index` type.
@@ -577,25 +593,30 @@ void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
 
 void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
                          GPUFuncOp kernelFunc, KernelDim3 gridSize,
-                         KernelDim3 blockSize, ValueRange kernelOperands) {
+                         KernelDim3 blockSize, Value dynamicSharedMemorySize,
+                         ValueRange kernelOperands) {
   // Add grid and block sizes as op operands, followed by the data operands.
   result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x,
                       blockSize.y, blockSize.z});
+  if (dynamicSharedMemorySize)
+    result.addOperands(dynamicSharedMemorySize);
   result.addOperands(kernelOperands);
   auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
   auto kernelSymbol =
       SymbolRefAttr::get(kernelModule.getNameAttr(),
                          {SymbolRefAttr::get(kernelFunc.getNameAttr())});
   result.addAttribute(getKernelAttrName(), kernelSymbol);
-  SmallVector<int32_t, 8> segmentSizes(8, 1);
+  SmallVector<int32_t, 9> segmentSizes(9, 1);
   segmentSizes.front() = 0; // Initially no async dependencies.
+  segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0;
   segmentSizes.back() = static_cast<int32_t>(kernelOperands.size());
   result.addAttribute(getOperandSegmentSizeAttr(),
                       builder.getI32VectorAttr(segmentSizes));
 }
 
 unsigned LaunchFuncOp::getNumKernelOperands() {
-  return getNumOperands() - asyncDependencies().size() - kNumConfigOperands;
+  return getNumOperands() - asyncDependencies().size() - kNumConfigOperands -
+         (dynamicSharedMemorySize() ? 1 : 0);
 }
 
 StringAttr LaunchFuncOp::getKernelModuleName() {
@@ -605,7 +626,8 @@ StringAttr LaunchFuncOp::getKernelModuleName() {
 StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }
 
 Value LaunchFuncOp::getKernelOperand(unsigned i) {
-  return getOperand(asyncDependencies().size() + kNumConfigOperands + i);
+  return getOperand(asyncDependencies().size() + kNumConfigOperands +
+                    (dynamicSharedMemorySize() ? 1 : 0) + i);
 }
 
 KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {

diff  --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 35fce506b662b..178bcf8742848 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -215,9 +215,12 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp,
                                   gpu::GPUFuncOp kernelFunc,
                                   ValueRange operands) {
   OpBuilder builder(launchOp);
+  // The launch op has an optional dynamic shared memory size. If it doesn't
+  // exist, we use zero.
   builder.create<gpu::LaunchFuncOp>(
       launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
-      launchOp.getBlockSizeOperandValues(), operands);
+      launchOp.getBlockSizeOperandValues(), launchOp.dynamicSharedMemorySize(),
+      operands);
   launchOp.erase();
 }
 

diff  --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
index 634385cf1a645..de81424616fb4 100644
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -20,14 +20,17 @@ module attributes {gpu.container_module} {
   func @foo(%buffer: memref<?xf32>) {
     %c8 = constant 8 : index
     %c32 = constant 32 : i32
+    %c256 = constant 256 : i32
     gpu.launch_func @kernel_module::@kernel
         blocks in (%c8, %c8, %c8)
         threads in (%c8, %c8, %c8)
+        dynamic_shared_memory_size %c256
         args(%c32 : i32, %buffer : memref<?xf32>)
     return
   }
 
-  // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
+  // CHECK-DAG: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
+  // CHECK-DAG: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
   // CHECK: [[ADDRESSOF:%.*]] = llvm.mlir.addressof @[[GLOBAL]]
   // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index)
   // CHECK: [[BINARY:%.*]] = llvm.getelementptr [[ADDRESSOF]]{{\[}}[[C0]], [[C0]]]
@@ -36,7 +39,6 @@ module attributes {gpu.container_module} {
   // CHECK: [[MODULE:%.*]] = llvm.call @mgpuModuleLoad([[BINARY]])
   // CHECK: [[FUNC:%.*]] = llvm.call @mgpuModuleGetFunction([[MODULE]], {{.*}})
 
-  // CHECK: [[C0_I32:%.*]] = llvm.mlir.constant(0 : i32)
   // CHECK: [[STREAM:%.*]] = llvm.call @mgpuStreamCreate
 
   // CHECK: [[NUM_PARAMS:%.*]] = llvm.mlir.constant(6 : i32) : i32
@@ -45,7 +47,7 @@ module attributes {gpu.container_module} {
   // CHECK: [[EXTRA_PARAMS:%.*]] = llvm.mlir.null : !llvm.ptr<ptr<i8>>
 
   // CHECK: llvm.call @mgpuLaunchKernel([[FUNC]], [[C8]], [[C8]], [[C8]],
-  // CHECK-SAME: [[C8]], [[C8]], [[C8]], [[C0_I32]], [[STREAM]],
+  // CHECK-SAME: [[C8]], [[C8]], [[C8]], [[C256]], [[STREAM]],
   // CHECK-SAME: [[PARAMS]], [[EXTRA_PARAMS]])
   // CHECK: llvm.call @mgpuStreamSynchronize
   // CHECK: llvm.call @mgpuStreamDestroy

diff  --git a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir
index b7732539fc4c3..6700fe92aaa85 100644
--- a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir
@@ -27,8 +27,10 @@ module attributes {gpu.container_module} {
 module attributes {gpu.container_module} {
   func @builtin() {
     %c0 = constant 1 : index
+    %c256 = constant 256 : i32
     gpu.launch_func @kernels::@builtin_workgroup_id_y
         blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0)
+        dynamic_shared_memory_size %c256
     return
   }
 

diff  --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 3c7a57d099e0d..c984bfa9e0cb9 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
 
 func @not_enough_sizes(%sz : index) {
-  // expected-error@+1 {{expected 6 operands, but found 5}}
+  // expected-error@+1 {{expected 6 or more operands, but found 5}}
   "gpu.launch"(%sz, %sz, %sz, %sz, %sz) ({
     gpu.return
   }) : (index, index, index, index, index) -> ()
@@ -56,7 +56,7 @@ module attributes {gpu.container_module} {
   func @launch_func_missing_callee_attribute(%sz : index) {
     // expected-error@+1 {{'gpu.launch_func' op requires attribute 'kernel'}}
     "gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz)
-        {operand_segment_sizes = dense<[0, 1, 1, 1, 1, 1, 1, 0]> : vector<8xi32>}
+        {operand_segment_sizes = dense<[0, 1, 1, 1, 1, 1, 1, 0, 0]> : vector<9xi32>}
         : (index, index, index, index, index, index) -> ()
     return
   }

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 2c4a13d96d6d9..1efe3dc544dd0 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -73,11 +73,14 @@ module attributes {gpu.container_module} {
     %1 = "op"() : () -> (memref<?xf32, 1>)
     // CHECK: %{{.*}} = constant 8
     %cst = constant 8 : index
+    %c0 = constant 0 : i32
     %t0 = gpu.wait async
 
     // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
     gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
 
+    gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) dynamic_shared_memory_size %c0 args(%0 : f32, %1 : memref<?xf32, 1>)
+
     // CHECK: gpu.launch_func @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})
     gpu.launch_func @kernels::@kernel_2 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
 


        


More information about the Mlir-commits mailing list