[Mlir-commits] [mlir] 08b63db - [MLIR][GPU] Add GPU launch op support for dynamic shared memory
Uday Bondhugula
llvmlistbot at llvm.org
Fri Oct 1 04:21:13 PDT 2021
Author: Uday Bondhugula
Date: 2021-10-01T16:46:07+05:30
New Revision: 08b63db8bb3ea847543351e1268be31ea327ad6b
URL: https://github.com/llvm/llvm-project/commit/08b63db8bb3ea847543351e1268be31ea327ad6b
DIFF: https://github.com/llvm/llvm-project/commit/08b63db8bb3ea847543351e1268be31ea327ad6b.diff
LOG: [MLIR][GPU] Add GPU launch op support for dynamic shared memory
Add support for dynamic shared memory for GPU launch ops: add an
optional operand to gpu.launch and gpu.launch_func ops to specify the
amount of "dynamic" shared memory to use. Update lowerings to connect
this operand to the GPU runtime.
Differential Revision: https://reviews.llvm.org/D110800
Added:
Modified:
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
mlir/test/Conversion/GPUToSPIRV/builtins.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/GPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
index f3ce173108590..97359bce201ef 100644
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -289,6 +289,7 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
SymbolRefAttr:$kernel,
Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
+ Optional<I32>:$dynamicSharedMemorySize,
Variadic<AnyType>:$operands)>,
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
let summary = "Launches a function as a GPU kernel";
@@ -317,7 +318,12 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
dimensions as arguments. When a lower-dimensional kernel is required,
unused sizes must be explicitly set to `1`.
- The remaining operands are passed as arguments to the kernel function.
+ The remaining operands are optional. The first optional operand corresponds
+ to the amount of dynamic shared memory a kernel's workgroup should be
+ allocated; when this operand is not present, a zero size is assumed.
+
+ The remaining operands, if present, are passed as arguments to the kernel
+ function.
Example:
@@ -360,6 +366,8 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
@kernels::@kernel_1 // Kernel function.
blocks in (%cst, %cst, %cst) // Grid size.
threads in (%cst, %cst, %cst) // Block size.
+ dynamic_shared_memory_size %s // (Optional) Amount of dynamic shared
+ // memory to allocate for a workgroup.
args(%arg0 : f32, // (Optional) Kernel arguments.
%arg1 : memref<?xf32, 1>)
}
@@ -370,7 +378,8 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
let builders = [
OpBuilder<(ins "GPUFuncOp":$kernelFunc, "KernelDim3":$gridSize,
- "KernelDim3":$blockSize, "ValueRange":$kernelOperands)>
+ "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
+ "ValueRange":$kernelOperands)>
];
let extraClassDeclaration = [{
@@ -411,24 +420,30 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
$kernel
`blocks` `in` ` ` `(`$gridSizeX`,` $gridSizeY`,` $gridSizeZ`)`
`threads` `in` ` ` `(`$blockSizeX`,` $blockSizeY`,` $blockSizeZ`)`
- custom<LaunchFuncOperands>($operands, type($operands))
- attr-dict
+ (`dynamic_shared_memory_size` $dynamicSharedMemorySize^)?
+ custom<LaunchFuncOperands>($operands, type($operands)) attr-dict
}];
}
def GPU_LaunchOp : GPU_Op<"launch">,
Arguments<(ins Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
- Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ)>,
+ Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
+ Optional<I32>:$dynamicSharedMemorySize)>,
Results<(outs)> {
let summary = "GPU kernel launch operation";
let description = [{
Launch a kernel on the specified grid of thread blocks. The body of the
kernel is defined by the single region that this operation contains. The
- operation takes six operands, with first three operands being grid sizes
- along x,y,z dimensions and the following three arguments being block sizes
- along x,y,z dimension. When a lower-dimensional kernel is required,
- unused sizes must be explicitly set to `1`.
+ operation takes six operands followed by an optional operand: the first
+ three operands are grid sizes along the x,y,z dimensions and the following
+ three are block sizes along the x,y,z dimensions. The last operand is
+ optional and corresponds to the amount of dynamic shared memory a kernel's
+ workgroup should be allocated; when this operand is not present, a zero size
+ is assumed.
+
+ When a lower-dimensional kernel is required, unused sizes must
+ be explicitly set to `1`.
The body region has _twelve_ arguments, grouped as follows:
@@ -442,7 +457,8 @@ def GPU_LaunchOp : GPU_Op<"launch">,
```
operation ::= `gpu.launch` `blocks` `(` ssa-id-list `)` `in` ssa-reassignment
`threads` `(` ssa-id-list `)` `in` ssa-reassignment
- region attr-dict?
+ (`dynamic_shared_memory_size` ssa-use)?
+ region attr-dict?
ssa-reassignment ::= `(` ssa-id `=` ssa-use (`,` ssa-id `=` ssa-use)* `)`
```
@@ -490,7 +506,8 @@ def GPU_LaunchOp : GPU_Op<"launch">,
let builders = [
OpBuilder<(ins "Value":$gridSizeX, "Value":$gridSizeY,
"Value":$gridSizeZ, "Value":$blockSizeX, "Value":$blockSizeY,
- "Value":$blockSizeZ)>
+ "Value":$blockSizeZ,
+ CArg<"Value", "nullptr">:$dynamic_shared_memory_size)>
];
let extraClassDeclaration = [{
@@ -510,6 +527,9 @@ def GPU_LaunchOp : GPU_Op<"launch">,
static StringRef getBlocksKeyword() { return "blocks"; }
static StringRef getThreadsKeyword() { return "threads"; }
+ static StringRef getDynamicSharedMemorySizeKeyword() {
+ return "dynamic_shared_memory_size";
+ }
/// The number of launch configuration operands, placed at the leading
/// positions of the operand list.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 40a4463f7dbe6..ff4ae51c4bde6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -745,13 +745,15 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
// Create array of pointers to kernel arguments.
auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
- launchKernelCallBuilder.create(loc, rewriter,
- {function.getResult(0), adaptor.gridSizeX(),
- adaptor.gridSizeY(), adaptor.gridSizeZ(),
- adaptor.blockSizeX(), adaptor.blockSizeY(),
- adaptor.blockSizeZ(),
- /*sharedMemBytes=*/zero, stream, kernelParams,
- /*extra=*/nullpointer});
+ Value dynamicSharedMemorySize = launchOp.dynamicSharedMemorySize()
+ ? launchOp.dynamicSharedMemorySize()
+ : zero;
+ launchKernelCallBuilder.create(
+ loc, rewriter,
+ {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
+ adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
+ adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
+ /*extra=*/nullpointer});
if (launchOp.asyncToken()) {
// Async launch: make dependent ops use the same stream.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7c1bf8e654b81..9a75bb55dda63 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -353,10 +353,13 @@ void gpu::addAsyncDependency(Operation *op, Value token) {
void LaunchOp::build(OpBuilder &builder, OperationState &result,
Value gridSizeX, Value gridSizeY, Value gridSizeZ,
- Value blockSizeX, Value blockSizeY, Value blockSizeZ) {
+ Value blockSizeX, Value blockSizeY, Value blockSizeZ,
+ Value dynamicSharedMemorySize) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands(
{gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ});
+ if (dynamicSharedMemorySize)
+ result.addOperands(dynamicSharedMemorySize);
// Create a kernel body region with kNumConfigRegionAttributes + N arguments,
// where the first kNumConfigRegionAttributes arguments have `index` type and
@@ -406,7 +409,8 @@ static LogicalResult verify(LaunchOp op) {
// for block/thread identifiers and grid/block sizes.
if (!op.body().empty()) {
if (op.body().getNumArguments() !=
- LaunchOp::kNumConfigOperands + op.getNumOperands())
+ LaunchOp::kNumConfigOperands + op.getNumOperands() -
+ (op.dynamicSharedMemorySize() ? 1 : 0))
return op.emitOpError("unexpected number of region arguments");
}
@@ -450,6 +454,9 @@ static void printLaunchOp(OpAsmPrinter &p, LaunchOp op) {
p << ' ' << op.getThreadsKeyword();
printSizeAssignment(p, op.getBlockSize(), op.getBlockSizeOperandValues(),
op.getThreadIds());
+ if (op.dynamicSharedMemorySize())
+ p << ' ' << op.getDynamicSharedMemorySizeKeyword() << ' '
+ << op.dynamicSharedMemorySize();
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(op->getAttrs());
@@ -521,6 +528,15 @@ static ParseResult parseLaunchOp(OpAsmParser &parser, OperationState &result) {
result.operands))
return failure();
+ OpAsmParser::OperandType dynamicSharedMemorySize;
+ if (!parser.parseOptionalKeyword(
+ LaunchOp::getDynamicSharedMemorySizeKeyword()))
+ if (parser.parseOperand(dynamicSharedMemorySize) ||
+ parser.resolveOperand(dynamicSharedMemorySize,
+ parser.getBuilder().getI32Type(),
+ result.operands))
+ return failure();
+
// Introduce the body region and parse it. The region has
// kNumConfigRegionAttributes arguments that correspond to
// block/thread identifiers and grid/block sizes, all of the `index` type.
@@ -577,25 +593,30 @@ void LaunchOp::getCanonicalizationPatterns(RewritePatternSet &rewrites,
void LaunchFuncOp::build(OpBuilder &builder, OperationState &result,
GPUFuncOp kernelFunc, KernelDim3 gridSize,
- KernelDim3 blockSize, ValueRange kernelOperands) {
+ KernelDim3 blockSize, Value dynamicSharedMemorySize,
+ ValueRange kernelOperands) {
// Add grid and block sizes as op operands, followed by the data operands.
result.addOperands({gridSize.x, gridSize.y, gridSize.z, blockSize.x,
blockSize.y, blockSize.z});
+ if (dynamicSharedMemorySize)
+ result.addOperands(dynamicSharedMemorySize);
result.addOperands(kernelOperands);
auto kernelModule = kernelFunc->getParentOfType<GPUModuleOp>();
auto kernelSymbol =
SymbolRefAttr::get(kernelModule.getNameAttr(),
{SymbolRefAttr::get(kernelFunc.getNameAttr())});
result.addAttribute(getKernelAttrName(), kernelSymbol);
- SmallVector<int32_t, 8> segmentSizes(8, 1);
+ SmallVector<int32_t, 9> segmentSizes(9, 1);
segmentSizes.front() = 0; // Initially no async dependencies.
+ segmentSizes[segmentSizes.size() - 2] = dynamicSharedMemorySize ? 1 : 0;
segmentSizes.back() = static_cast<int32_t>(kernelOperands.size());
result.addAttribute(getOperandSegmentSizeAttr(),
builder.getI32VectorAttr(segmentSizes));
}
unsigned LaunchFuncOp::getNumKernelOperands() {
- return getNumOperands() - asyncDependencies().size() - kNumConfigOperands;
+ return getNumOperands() - asyncDependencies().size() - kNumConfigOperands -
+ (dynamicSharedMemorySize() ? 1 : 0);
}
StringAttr LaunchFuncOp::getKernelModuleName() {
@@ -605,7 +626,8 @@ StringAttr LaunchFuncOp::getKernelModuleName() {
StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }
Value LaunchFuncOp::getKernelOperand(unsigned i) {
- return getOperand(asyncDependencies().size() + kNumConfigOperands + i);
+ return getOperand(asyncDependencies().size() + kNumConfigOperands +
+ (dynamicSharedMemorySize() ? 1 : 0) + i);
}
KernelDim3 LaunchFuncOp::getGridSizeOperandValues() {
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index 35fce506b662b..178bcf8742848 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -215,9 +215,12 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp,
gpu::GPUFuncOp kernelFunc,
ValueRange operands) {
OpBuilder builder(launchOp);
+ // The launch op has an optional dynamic shared memory size. If it is absent,
+ // a null Value is forwarded, so the launch_func op is built without that
+ // operand and a zero size is assumed.
builder.create<gpu::LaunchFuncOp>(
launchOp.getLoc(), kernelFunc, launchOp.getGridSizeOperandValues(),
- launchOp.getBlockSizeOperandValues(), operands);
+ launchOp.getBlockSizeOperandValues(), launchOp.dynamicSharedMemorySize(),
+ operands);
launchOp.erase();
}
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
index 634385cf1a645..de81424616fb4 100644
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -20,14 +20,17 @@ module attributes {gpu.container_module} {
func @foo(%buffer: memref<?xf32>) {
%c8 = constant 8 : index
%c32 = constant 32 : i32
+ %c256 = constant 256 : i32
gpu.launch_func @kernel_module::@kernel
blocks in (%c8, %c8, %c8)
threads in (%c8, %c8, %c8)
+ dynamic_shared_memory_size %c256
args(%c32 : i32, %buffer : memref<?xf32>)
return
}
- // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
+ // CHECK-DAG: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
+ // CHECK-DAG: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
// CHECK: [[ADDRESSOF:%.*]] = llvm.mlir.addressof @[[GLOBAL]]
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index)
// CHECK: [[BINARY:%.*]] = llvm.getelementptr [[ADDRESSOF]]{{\[}}[[C0]], [[C0]]]
@@ -36,7 +39,6 @@ module attributes {gpu.container_module} {
// CHECK: [[MODULE:%.*]] = llvm.call @mgpuModuleLoad([[BINARY]])
// CHECK: [[FUNC:%.*]] = llvm.call @mgpuModuleGetFunction([[MODULE]], {{.*}})
- // CHECK: [[C0_I32:%.*]] = llvm.mlir.constant(0 : i32)
// CHECK: [[STREAM:%.*]] = llvm.call @mgpuStreamCreate
// CHECK: [[NUM_PARAMS:%.*]] = llvm.mlir.constant(6 : i32) : i32
@@ -45,7 +47,7 @@ module attributes {gpu.container_module} {
// CHECK: [[EXTRA_PARAMS:%.*]] = llvm.mlir.null : !llvm.ptr<ptr<i8>>
// CHECK: llvm.call @mgpuLaunchKernel([[FUNC]], [[C8]], [[C8]], [[C8]],
- // CHECK-SAME: [[C8]], [[C8]], [[C8]], [[C0_I32]], [[STREAM]],
+ // CHECK-SAME: [[C8]], [[C8]], [[C8]], [[C256]], [[STREAM]],
// CHECK-SAME: [[PARAMS]], [[EXTRA_PARAMS]])
// CHECK: llvm.call @mgpuStreamSynchronize
// CHECK: llvm.call @mgpuStreamDestroy
diff --git a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir
index b7732539fc4c3..6700fe92aaa85 100644
--- a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir
@@ -27,8 +27,10 @@ module attributes {gpu.container_module} {
module attributes {gpu.container_module} {
func @builtin() {
%c0 = constant 1 : index
+ %c256 = constant 256 : i32
gpu.launch_func @kernels::@builtin_workgroup_id_y
blocks in (%c0, %c0, %c0) threads in (%c0, %c0, %c0)
+ dynamic_shared_memory_size %c256
return
}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 3c7a57d099e0d..c984bfa9e0cb9 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
func @not_enough_sizes(%sz : index) {
- // expected-error @+1 {{expected 6 operands, but found 5}}
+ // expected-error @+1 {{expected 6 or more operands, but found 5}}
"gpu.launch"(%sz, %sz, %sz, %sz, %sz) ({
gpu.return
}) : (index, index, index, index, index) -> ()
@@ -56,7 +56,7 @@ module attributes {gpu.container_module} {
func @launch_func_missing_callee_attribute(%sz : index) {
// expected-error @+1 {{'gpu.launch_func' op requires attribute 'kernel'}}
"gpu.launch_func"(%sz, %sz, %sz, %sz, %sz, %sz)
- {operand_segment_sizes = dense<[0, 1, 1, 1, 1, 1, 1, 0]> : vector<8xi32>}
+ {operand_segment_sizes = dense<[0, 1, 1, 1, 1, 1, 1, 0, 0]> : vector<9xi32>}
: (index, index, index, index, index, index) -> ()
return
}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 2c4a13d96d6d9..1efe3dc544dd0 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -73,11 +73,14 @@ module attributes {gpu.container_module} {
%1 = "op"() : () -> (memref<?xf32, 1>)
// CHECK: %{{.*}} = constant 8
%cst = constant 8 : index
+ %c0 = constant 0 : i32
%t0 = gpu.wait async
// CHECK: gpu.launch_func @kernels::@kernel_1 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) args(%{{.*}} : f32, %{{.*}} : memref<?xf32, 1>)
gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%0 : f32, %1 : memref<?xf32, 1>)
+ gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) dynamic_shared_memory_size %c0 args(%0 : f32, %1 : memref<?xf32, 1>)
+
// CHECK: gpu.launch_func @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})
gpu.launch_func @kernels::@kernel_2 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
More information about the Mlir-commits
mailing list