[Mlir-commits] [mlir] [mlir][gpu] Introduce `gpu.dynamic.shared.memory` Op (PR #71516)
Guray Ozen
llvmlistbot at llvm.org
Tue Nov 7 03:06:23 PST 2023
https://github.com/grypp created https://github.com/llvm/llvm-project/pull/71516
While the `gpu.launch` Op allows setting the size via the `dynamic_shared_memory_size` argument, accessing the dynamic shared memory is very convoluted. This PR implements the proposed Op, `gpu.dynamic.shared.memory`, which aims to simplify the use of dynamic shared memory.
RFC: https://discourse.llvm.org/t/rfc-simplifying-dynamic-shared-memory-access-in-gpu/
**Proposal from RFC**
This PR introduces the `gpu.dynamic.shared.memory` Op to make efficient use of the dynamic shared memory feature. Dynamic shared memory is a powerful feature that allows shared memory to be allocated at runtime, when the kernel is launched from the host. Afterwards, the memory can be accessed directly from the device. I believe a similar story exists for AMDGPU.
**New Op Features**
- No more 0-Sized Global Symbol Generation: The lowering will hide the 1st and 3rd steps of the current approach (creating the 0-sized global and accessing it with `memref.get_global`).
- Simplified Shared Memory Access: No need for `reinterpret_cast` or `subview`; the offset arguments are sufficient.
- Compile-time Bound Check: If `dynamic_shared_memory_size` and the offsets are compile-time constants, the Op verifier reports an error when the requested offset exceeds `dynamic_shared_memory_size`.
- Runtime Bound Check (TODO): We can add a `{dynamicBoundCheck}` attribute that checks at runtime that the offset does not exceed `dynamic_shared_memory_size`. This is optional and definitely adds overhead, but it can be beneficial for debugging.
**Current Way of Using Dynamic Shared Memory in MLIR**
Let me illustrate the challenges of using dynamic shared memory in MLIR with the example below. The process involves several steps:
- `memref.global`: the 0-sized array NVPTX expects
- `dynamic_shared_memory_size`: set the size of dynamic shared memory
- `memref.get_global`: access the global symbol
- `reinterpret_cast` and `subview`: many Ops for pointer arithmetic
```
// Step 1. Create 0-sized global symbol. Manually set the alignment
memref.global "private" @dynamicShmem : memref<0xf16, 3> { alignment = 16 }
func.func @main() {
// Step 2. Allocate shared memory
gpu.launch blocks(...) threads(...)
dynamic_shared_memory_size %c10000 {
// Step 3. Access the global object
%shmem = memref.get_global @dynamicShmem : memref<0xf16, 3>
// Step 4. A sequence of `memref.reinterpret_cast` and `memref.subview` operations.
%4 = memref.reinterpret_cast %shmem to offset: [0], sizes: [14, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
%5 = memref.subview %4[7, 0, 0][7, 64, 128][1,1,1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
%6 = memref.subview %5[2, 0, 0][1, 64, 128][1,1,1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<64x128xf16, strided<[128, 1], offset: 73728>, 3>
%7 = memref.subview %6[0, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>
%8 = memref.subview %6[32, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>
// Step 5. Use
"test.use.shared.memory"(%7) : (memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>) -> (index)
"test.use.shared.memory"(%8) : (memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>) -> (index)
gpu.terminator
}
}
```
Here is the same program written with `gpu.dynamic.shared.memory`:
```
func.func @main() {
gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
%i = arith.constant 18 : index
// Step 1: Obtain shared memory directly
%7 = gpu.dynamic.shared.memory [%i,0,0] : memref<64x64xf16, 3>
%i2 = arith.addi %i, %c1 : index
%8 = gpu.dynamic.shared.memory [%i2,0,0] : memref<64x64xf16, 3>
// Step 2: Utilize the shared memory
"test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
"test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
gpu.terminator
}
}
```
From b79bcbe78e9030c3aff552eb7a8884599ee61924 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 12:05:49 +0100
Subject: [PATCH] [mlir][gpu] Introduce `gpu.dynamic.shared.memory` Op
---
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td | 10 ++
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 13 ++
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 74 +++++++++-
.../include/mlir/Dialect/LLVMIR/NVVMDialect.h | 3 +
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 103 ++++++++++++++
.../lib/Conversion/GPUCommon/GPUOpsLowering.h | 22 +++
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 3 +
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 132 +++++++++++++++---
.../Dialect/GPU/dynamic-shared-memory.mlir | 64 +++++++++
mlir/test/Dialect/GPU/invalid.mlir | 49 +++++++
10 files changed, 452 insertions(+), 21 deletions(-)
create mode 100644 mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 755c82d8b75c9c0..057b507c394e80f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -52,6 +52,16 @@ def GPU_Dialect : Dialect {
/// Returns the numeric value used to identify the private memory address
/// space.
static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
+
+ /// Return true if the given MemRefType has an integer address
+ /// space that matches the workgroup memory address space or
+ /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
+ static bool hasWorkgroupMemoryAddressSpace(MemRefType type);
+
+ /// Return true if the given Attribute has an integer address
+ /// space that matches the workgroup memory address space or
+ /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
+ static bool isWorkgroupMemoryAddressSpace(Attribute memorySpace);
}];
let dependentDialects = ["arith::ArithDialect"];
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 14a1fac5fd255f3..286856324950eb7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -17,6 +17,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -32,6 +33,18 @@
namespace mlir {
namespace gpu {
+/// GPU memory space identifiers.
+enum GPUMemorySpace {
+ /// Generic memory space identifier.
+ kGenericMemorySpace = 0,
+
+ /// Global memory space identifier.
+ kGlobalMemorySpace = 1,
+
+ /// Shared memory space identifier.
+ kSharedMemorySpace = 3
+};
+
/// Utility class for the GPU dialect to represent triples of `Value`s
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
struct KernelDim3 {
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6375d35f4311295..eac5b0096a3e10c 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -433,6 +433,74 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
let hasVerifier = 1;
}
+def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic.shared.memory",
+ [MemoryEffects<[MemWrite]>] > {
+ let summary = "Get the memref for dynamic shared memory";
+
+ let description = [{
+ This operation returns shared memory, also referred to as workgroup memory,
+ at the given offsets.
+
+ It is possible to use both constants and SSA values as offsets.
+
+ If this operation is used within a `gpu.launch`, the verifier will make an
+ attempt to verify that the offsets fall within bounds by utilizing the
+ `dynamic_shared_memory_size` argument of `gpu.launch` when the values are
+ compile-time constants. Otherwise, the verifier does not perform offset
+ checks.
+
+ Examples:
+ ```mlir
+ // Constant value, offset = 32 * 64 * sizeof(f32) * 1
+ %0 = gpu.dynamic.shared.memory [1] : memref<32x64xf32, #gpu.address_space<workgroup>>
+
+ // Multi-dimensional constant values, offset = (32 * 64 * 1 + 8) * sizeof(f32)
+ %0 = gpu.dynamic.shared.memory [1, 0, 8] : memref<32x64xf32, #gpu.address_space<workgroup>>
+
+ // Multi-dimensional dynamic values, offset = (32 * 64 * %1) * sizeof(f32)
+ %0 = gpu.dynamic.shared.memory [%1, 0, 0] : memref<32x64xf32, #gpu.address_space<workgroup>>
+
+ // Multi-dimensional mixed values, offset = (32 * 64 * %1 + 8) * sizeof(f32)
+ %0 = gpu.dynamic.shared.memory [%1, 0, 8] : memref<32x64xf32, #gpu.address_space<workgroup>>
+ ```
+ }];
+
+ let arguments = (ins
+ Variadic<Index>:$dynamic_offsets,
+ DenseI64ArrayAttr:$static_offsets
+ );
+
+ let results = (outs Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$resultMemref);
+
+ let assemblyFormat = [{
+ custom<DynamicIndexList>($dynamic_offsets, $static_offsets)
+ attr-dict
+ `:` type($resultMemref)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Type":$memref, "int64_t":$offsets)>,
+ OpBuilder<(ins "Type":$memref, "OpFoldResult":$offsets)>,
+ OpBuilder<(ins "Type":$memref, "ArrayRef<int64_t>":$offsets)>,
+ OpBuilder<(ins "Type":$memref, "ArrayRef<OpFoldResult>":$offsets)>,
+ ];
+
+ let extraClassDeclaration = [{
+ /// Return a vector with all the static and dynamic offsets indices.
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ OpBuilder builder(getContext());
+ return getMixedValues(getStaticOffsets(), getDynamicOffsets(), builder);
+ }
+
+ bool hasDynamicOffsets() {
+ auto dynPos = getDynamicOffsets();
+ return std::any_of(dynPos.begin(), dynPos.end(),
+ [](Value operand) { return operand != nullptr; });
+ }
+ }];
+ let hasVerifier = 1;
+}
+
def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
@@ -587,7 +655,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
- Optional<I32>:$dynamicSharedMemorySize)>,
+ Optional<I32>:$dynamicSharedMemorySize,
+ OptionalAttr<I32Attr>:$guray)>,
Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
let summary = "GPU kernel launch operation";
@@ -693,7 +762,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
CArg<"Type", "nullptr">:$asyncTokenType,
CArg<"ValueRange", "{}">:$asyncDependencies,
CArg<"TypeRange", "{}">:$workgroupAttributions,
- CArg<"TypeRange", "{}">:$privateAttributions)>
+ CArg<"TypeRange", "{}">:$privateAttributions,
+ CArg<"IntegerAttr", "IntegerAttr()">:$guray)>
];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 8ff8f850a9c1858..08019e77ae6af8a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -27,6 +27,9 @@
namespace mlir {
namespace NVVM {
+// Shared memory has 128-bit alignment
+constexpr int kSharedMemoryAlignmentBit = 128;
+
/// NVVM memory space identifiers.
enum NVVMMemorySpace {
/// Global memory space identifier.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6d2585aa30ab4c5..c8f809ee88c54d7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
#include "GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -554,6 +555,108 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
return IntegerAttr::get(IntegerType::get(ctx, 64), space);
}
+/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
+/// or reuses an existing one.
+LLVM::GlobalOp getDynamicSharedMemorySymbol(
+ ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
+ const LLVMTypeConverter *typeConverter, unsigned alignmentBit) {
+ std::optional<LLVM::GlobalOp> existingGlobalOp;
+
+ MemRefType memrefType = op.getResultMemref().getType();
+ assert(memrefType && memrefType.hasStaticShape() &&
+ "expected static shaped memref type");
+
+ LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+ assert(funcOp && "cannot find llvm.func op");
+
+ gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
+ assert(moduleOp && "cannot find gpu.module op");
+
+ // Use already generated global op if it exists
+ int index = 0;
+ std::string prefix = llvm::formatv("__shmem_{0}", funcOp.getSymName());
+ moduleOp->walk([&](LLVM::GlobalOp globalOp) {
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+ if (arrayType.getNumElements() == 0) {
+ existingGlobalOp = globalOp;
+ return WalkResult::interrupt();
+ }
+ }
+ if (globalOp.getSymName().startswith(prefix))
+ index++;
+ return WalkResult::advance();
+ });
+ if (existingGlobalOp.has_value())
+ return existingGlobalOp.value();
+
+ // Generate a new global op
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(&moduleOp.front());
+
+ auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
+ typeConverter->convertType(memrefType.getElementType()), 0);
+ std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
+ // TODO: better alignment calculation
+ // `alignmentBit` is given in bits, while LLVM::GlobalOp expects the
+ // alignment in bytes (e.g. 128 bits -> 16 bytes).
+ uint64_t alignmentByte = alignmentBit / 8;
+ return rewriter.create<LLVM::GlobalOp>(
+ funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+ LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignmentByte,
+ mlir::gpu::GPUMemorySpace::kSharedMemorySpace);
+}
+
+LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
+ gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ MemRefType memrefType = op.getResultMemref().getType();
+ assert(memrefType && "memref is not valid");
+ Type elementType = typeConverter->convertType(memrefType.getElementType());
+
+ // Step 1: Get the existing global symbol for the dynamic shared memory, or
+ // generate a new one
+ LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
+ rewriter, op, getTypeConverter(), alignmentBit);
+ assert(shmemOp && "cannot find module op or failed generating global op");
+
+ // Step 2. Get address of the global symbol
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(op);
+ auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+ Type baseType = basePtr->getResultTypes().front();
+
+ // Step 3. Fill mixed dynamic and static offsets
+ SmallVector<LLVM::GEPArg> gepArgs;
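+ // Dynamic offsets are stored densely, so track their position separately
+ // from the position in the static offset list.
+ unsigned dynamicIdx = 0;
+ for (int64_t value : op.getStaticOffsets()) {
+ if (ShapedType::isDynamic(value))
+ gepArgs.push_back(LLVM::GEPArg(adaptor.getDynamicOffsets()[dynamicIdx++]));
+ else
+ gepArgs.push_back(LLVM::GEPArg(value));
+ }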
+
+ // Step 4. Generate GEP using offsets
+ Type gepResultType = elementType;
+ if (memrefType.hasStaticShape()) {
+ for (int64_t numElem : llvm::reverse(memrefType.getShape())) {
+ gepResultType = LLVM::LLVMArrayType::get(gepResultType, numElem);
+ }
+ }
+ Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, gepResultType,
+ basePtr, gepArgs);
+
+ // Step 5. Create a memref descriptor
+ SmallVector<Value> shape, strides;
+ Value sizeBytes;
+ getMemRefDescriptorSizes(loc, memrefType, {}, rewriter, shape, strides,
+ sizeBytes);
+
+ auto memRefDescriptor = this->createMemRefDescriptor(
+ loc, memrefType, shmemPtr, shmemPtr, shape, strides, rewriter);
+
+ // Step 6. Replace the op with memref descriptor
+ rewriter.replaceOp(op, {memRefDescriptor});
+ return success();
+}
+
void mlir::populateGpuMemorySpaceAttributeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
typeConverter.addTypeAttributeConversion(
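Step 3 of the lowering interleaves two lists. `static_offsets` holds a placeholder at each dynamic position, while the `dynamic_offsets` operands are densely packed. The dynamic list only has one entry per placeholder, so its cursor must advance independently of the static position. For example, in `[4, %d]` the single SSA value sits at dynamic index 0, not 1. Below is a standalone C++ sketch of that merge. `kDynamic` and `GEPArg` are hypothetical stand-ins for `ShapedType::kDynamic` and `LLVM::GEPArg`, with SSA values modelled by their names.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <string>
#include <variant>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumed here to be the int64_t minimum).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Stand-in for LLVM::GEPArg: a constant index or an SSA value (by name).
using GEPArg = std::variant<int64_t, std::string>;

// Merge static offsets (with kDynamic placeholders) and the densely packed
// dynamic operands, using a separate cursor for the dynamic list.
std::vector<GEPArg> mergeOffsets(const std::vector<int64_t> &staticOffsets,
                                 const std::vector<std::string> &dynamicOffsets) {
  std::vector<GEPArg> args;
  size_t dynIdx = 0;
  for (int64_t value : staticOffsets) {
    if (value == kDynamic) {
      assert(dynIdx < dynamicOffsets.size() && "missing dynamic offset");
      args.emplace_back(dynamicOffsets[dynIdx++]);
    } else {
      args.emplace_back(value);
    }
  }
  assert(dynIdx == dynamicOffsets.size() && "unused dynamic offset");
  return args;
}
```

Reusing the static position to index the dynamic list only works by coincidence when every offset is dynamic, as in the test's `[%c100, %d]`.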
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index bd90286494d8035..1805e2b06f40481 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,28 @@
namespace mlir {
+/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
+/// creates a 0-sized global array symbol, as LLVM expects. Subsequently,
+/// it computes the address using 'getelementptr' with its offset arguments.
+/// Finally, it constructs a memref descriptor with these values and returns it.
+struct GPUDynamicSharedMemoryOpLowering
+ : public ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp> {
+ using ConvertOpToLLVMPattern<
+ gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
+ GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
+ unsigned alignmentBit = 0)
+ : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
+ alignmentBit(alignmentBit) {}
+
+ LogicalResult
+ matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+
+private:
+ // Alignment of the dynamic shared memory symbol, in bits.
+ unsigned alignmentBit;
+};
+
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
GPUFuncOpLowering(const LLVMTypeConverter &converter,
unsigned allocaAddrSpace, unsigned workgroupAddrSpace,
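`getDynamicSharedMemorySymbol` reuses the first 0-sized array global it finds in the enclosing `gpu.module`. If none exists, it names a new global `__shmem_<func>_<n>`, where `n` counts existing globals with that prefix, which is why the test expects `@__shmem_dynamic_shared_memory_kernel_0`. The C++ sketch below is a simplified model of that naming logic: `GlobalInfo` and `dynamicShmemSymbolName` are hypothetical, and the module walk is reduced to a vector scan.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Simplified stand-in for an llvm.mlir.global: its symbol name and, for array
// globals, the number of elements (-1 for non-array globals).
struct GlobalInfo {
  std::string name;
  int64_t numElements;
};

// Reuse the first 0-sized array global; otherwise build "__shmem_<func>_<n>",
// where n counts the globals seen so far that start with the prefix.
std::string dynamicShmemSymbolName(const std::vector<GlobalInfo> &globals,
                                   const std::string &funcName) {
  assert(!funcName.empty() && "expected a function name");
  const std::string prefix = "__shmem_" + funcName;
  int index = 0;
  for (const GlobalInfo &g : globals) {
    if (g.numElements == 0)
      return g.name; // Reuse the existing 0-sized symbol.
    if (g.name.rfind(prefix, 0) == 0) // g.name starts with prefix
      ++index;
  }
  return prefix + "_" + std::to_string(index);
}
```

For an empty module, `dynamicShmemSymbolName({}, "dynamic_shared_memory_kernel")` returns `__shmem_dynamic_shared_memory_kernel_0`.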
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 935e3d2a4095003..86a77f557cb9579 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -325,6 +325,9 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
converter);
+ patterns.add<GPUDynamicSharedMemoryOpLowering>(
+ converter, NVVM::kSharedMemoryAlignmentBit);
+
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
// memory space and does not support `alloca`s with addrspace(5).
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5eb2cadc884e151..cc86ad74ea3ca7d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -164,17 +164,20 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
// GPUDialect
//===----------------------------------------------------------------------===//
-/// GPU memory space identifiers.
-enum GPUMemorySpace {
- /// Generic memory space identifier.
- kGenericMemorySpace = 0,
-
- /// Global memory space identifier.
- kGlobalMemorySpace = 1,
+bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
+ if (!memorySpace)
+ return false;
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+ return intAttr.getInt() == GPUMemorySpace::kSharedMemorySpace;
+ if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuAttr.getValue() == getWorkgroupAddressSpace();
+ return false;
+}
- /// Shared memory space identifier.
- kSharedMemorySpace = 3
-};
+bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
+ Attribute memorySpace = type.getMemorySpace();
+ return isWorkgroupMemoryAddressSpace(memorySpace);
+}
bool GPUDialect::isKernel(Operation *op) {
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
@@ -612,13 +615,16 @@ void gpu::addAsyncDependency(Operation *op, Value token) {
// LaunchOp
//===----------------------------------------------------------------------===//
+static constexpr int64_t kDynamic = std::numeric_limits<int32_t>::min();
+
void LaunchOp::build(OpBuilder &builder, OperationState &result,
Value gridSizeX, Value gridSizeY, Value gridSizeZ,
Value getBlockSizeX, Value getBlockSizeY,
Value getBlockSizeZ, Value dynamicSharedMemorySize,
Type asyncTokenType, ValueRange asyncDependencies,
TypeRange workgroupAttributions,
- TypeRange privateAttributions) {
+ TypeRange privateAttributions,
+ IntegerAttr dynamicSharedMemorySizeAttr) {
// Add a WorkGroup attribution attribute. This attribute is required to
// identify private attributions in the list of block argguments.
result.addAttribute(getNumWorkgroupAttributionsAttrName(),
@@ -634,7 +640,10 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
getBlockSizeY, getBlockSizeZ});
if (dynamicSharedMemorySize)
result.addOperands(dynamicSharedMemorySize);
-
+ if (!dynamicSharedMemorySizeAttr)
+ dynamicSharedMemorySizeAttr = builder.getI32IntegerAttr(kDynamic);
+
+ result.addAttribute("guray", dynamicSharedMemorySizeAttr);
// Create a kernel body region with kNumConfigRegionAttributes + N memory
// attributions, where the first kNumConfigRegionAttributes arguments have
// `index` type and the rest have the same types as the data operands.
@@ -759,6 +768,11 @@ void LaunchOp::print(OpAsmPrinter &p) {
if (getDynamicSharedMemorySize())
p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
<< getDynamicSharedMemorySize();
+ else if (getGurayAttr() && getGurayAttr().getInt() != kDynamic)
+ p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
+ << getGurayAttr().getInt();
+
printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
@@ -768,7 +782,8 @@ void LaunchOp::print(OpAsmPrinter &p) {
p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
LaunchOp::getOperandSegmentSizeAttr(),
- getNumWorkgroupAttributionsAttrName()});
+ getNumWorkgroupAttributionsAttrName(),
+ "guray"});
}
// Parse the size assignment blocks for blocks and threads. These have the form
@@ -854,12 +869,19 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
bool hasDynamicSharedMemorySize = false;
if (!parser.parseOptionalKeyword(
LaunchOp::getDynamicSharedMemorySizeKeyword())) {
- hasDynamicSharedMemorySize = true;
- if (parser.parseOperand(dynamicSharedMemorySize) ||
- parser.resolveOperand(dynamicSharedMemorySize,
- parser.getBuilder().getI32Type(),
- result.operands))
- return failure();
+ IntegerAttr shmemAttr;
+ OptionalParseResult shmemAttrResult = parser.parseOptionalAttribute(
+ shmemAttr, parser.getBuilder().getI32Type());
+ if (!shmemAttrResult.has_value()) {
+ hasDynamicSharedMemorySize = true;
+ shmemAttr = parser.getBuilder().getI32IntegerAttr(kDynamic);
+ if (parser.parseOperand(dynamicSharedMemorySize) ||
+ parser.resolveOperand(dynamicSharedMemorySize,
+ parser.getBuilder().getI32Type(),
+ result.operands))
+ return failure();
+ } else if (failed(*shmemAttrResult)) {
+ return failure();
+ }
+ result.addAttribute("guray", shmemAttr);
}
// Create the region arguments, it has kNumConfigRegionAttributes arguments
@@ -2024,6 +2046,78 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// DynamicSharedMemoryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::DynamicSharedMemoryOp::verify() {
+ MemRefType memrefType = getResultMemref().getType();
+ unsigned long rank = memrefType.getRank();
+ unsigned long numOffsets = getStaticOffsets().size();
+
+ // The number of offsets can be one larger than the memref rank.
+ if ((numOffsets + 1) < rank) {
+ return emitOpError("Number of offset must match the rank of the memref");
+ }
+
+ // Check address space
+ if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
+ return emitOpError() << "Address space must be "
+ << gpu::AddressSpaceAttr::getMnemonic() << "<"
+ << stringifyEnum(gpu::AddressSpace::Workgroup)
+ << "> or " << int(GPUMemorySpace::kSharedMemorySpace)
+ << ".";
+ }
+
+ if (memrefType.hasStaticShape()) {
+ std::optional<int64_t> shmemUpperBound = std::nullopt;
+
+ // Calculate upper bound of the dynamic shared memory size.
+ if (auto launchOp =
+ this->getOperation()->getParentOfType<gpu::LaunchOp>()) {
+ // The size operand is optional; only inspect it when it is present.
+ if (Value sizeValue = launchOp.getDynamicSharedMemorySize()) {
+ if (auto constOp = sizeValue.getDefiningOp<mlir::arith::ConstantOp>()) {
+ if (auto dynamicSharedMemSizeAttr =
+ llvm::dyn_cast<IntegerAttr>(constOp.getValueAttr())) {
+ shmemUpperBound = dynamicSharedMemSizeAttr.getInt();
+ }
+ }
+ }
+ }
+
+ // Check upper bound of memory space and compare it with allocated one
+ if (shmemUpperBound.has_value()) {
+ int64_t requestedUpperBound = 0;
+ // Calculate upper bound when offsets are visible
+ if (!this->hasDynamicOffsets()) {
+ for (auto [idx, offset] : llvm::enumerate(getStaticOffsets())) {
+ if (ShapedType::isDynamic(offset))
+ continue;
+ int64_t memrefSize = memrefType.getElementTypeBitWidth() / 8;
+ for (int64_t j = (rank - 1); j >= int64_t(idx); --j)
+ memrefSize *= memrefType.getShape()[j];
+ requestedUpperBound += memrefSize * offset;
+ }
+ }
+ // Calculate at least the size of memref when offsets are not visible
+ else {
+ requestedUpperBound = memrefType.getNumElements() *
+ memrefType.getElementTypeBitWidth() / 8;
+ }
+ if (requestedUpperBound > shmemUpperBound.value()) {
+ return emitOpError()
+ << "gpu.launch allocates a " << shmemUpperBound.value()
+ << " bytes of dynamic shared "
+ "memory, but the Op's access upper bound requires "
+ << requestedUpperBound
+ << " bytes, "
+ "which exceeds the currently allocated memory limit.";
+ }
+ }
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// GPU target options
//===----------------------------------------------------------------------===//
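The verifier's compile-time check compares a byte count against a constant `dynamic_shared_memory_size`. With all-constant offsets, it sums `offset[k]` multiplied by the byte size of `shape[k..rank-1]`. If any offset is an SSA value, it falls back to the size of the memref itself. The standalone C++ sketch below models only that arithmetic, not the MLIR verifier API; both function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Bytes the op is assumed to require, following the arithmetic in
// DynamicSharedMemoryOp::verify(). `staticOffsets` is std::nullopt when at
// least one offset is an SSA value rather than a compile-time constant.
int64_t requestedUpperBoundBytes(
    const std::vector<int64_t> &shape, int64_t elementBitWidth,
    const std::optional<std::vector<int64_t>> &staticOffsets) {
  assert(elementBitWidth % 8 == 0 && "expected byte-sized elements");
  const int64_t elementBytes = elementBitWidth / 8;
  if (!staticOffsets) {
    // Offsets unknown: require at least the size of the memref itself.
    int64_t numElements = 1;
    for (int64_t dim : shape)
      numElements *= dim;
    return numElements * elementBytes;
  }
  // Constant offsets: offset k advances by the byte size of shape[k..rank-1].
  int64_t bytes = 0;
  for (size_t k = 0; k < staticOffsets->size(); ++k) {
    int64_t stride = elementBytes;
    for (size_t j = k; j < shape.size(); ++j)
      stride *= shape[j];
    bytes += stride * (*staticOffsets)[k];
  }
  return bytes;
}

// Mirrors the error condition: the requested bound exceeds the constant
// `dynamic_shared_memory_size` of the enclosing gpu.launch.
bool exceedsDynamicShmem(int64_t allocatedBytes, int64_t requestedBytes) {
  return requestedBytes > allocatedBytes;
}
```

For `[1, 0, 8]` on `memref<32x64xf32>`, this yields 8224 bytes, matching the `(32 * 64 * 1 + 8) * sizeof(f32)` comment in the Op description. As in the patch, the constant-offset case measures where the access starts, not where it ends.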
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
new file mode 100644
index 000000000000000..cd2108b975a732b
--- /dev/null
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
+
+gpu.module @modules {
+ // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x f32>
+
+ // CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
+ // CHECK-SAME: %[[arg0:.+]]: i64)
+ gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %0 = gpu.dynamic.shared.memory [1, 0, 0] : memref<32x64xf32, #gpu.address_space<workgroup>>
+ %1 = gpu.dynamic.shared.memory [1, 0, 0] : memref<32x32xf32, 3>
+ %2 = gpu.dynamic.shared.memory [4, 234] : memref<32x32xf32, #gpu.address_space<workgroup>>
+ %3 = gpu.dynamic.shared.memory [%c100, 4] : memref<32x32xf32, #gpu.address_space<workgroup>>
+ %4 = gpu.dynamic.shared.memory [%c100, %d] : memref<128x64xf16, #gpu.address_space<workgroup>>
+ %5 = gpu.dynamic.shared.memory [32, 0, 0] : memref<32x8xf32, #gpu.address_space<workgroup>>
+ "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+ "test.use.shared.memory"(%1) : (memref<32x32xf32, 3>) -> ()
+ "test.use.shared.memory"(%2) : (memref<32x32xf32, #gpu.address_space<workgroup>>) -> ()
+ "test.use.shared.memory"(%3) : (memref<32x32xf32, #gpu.address_space<workgroup>>) -> ()
+ "test.use.shared.memory"(%4) : (memref<128x64xf16, #gpu.address_space<workgroup>>) -> ()
+ "test.use.shared.memory"(%5) : (memref<32x8xf32, #gpu.address_space<workgroup>>) -> ()
+
+ // CHECK: %[[S2:.+]] = llvm.mlir.constant(0 : index) : i64
+ // CHECK: %[[S3:.+]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: %[[S4:.+]] = llvm.mlir.constant(64 : index) : i64
+ // CHECK: %[[S5:.+]] = llvm.mlir.constant(32 : index) : i64
+ // CHECK: %[[S6:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_0 : !llvm.ptr<3>
+ // CHECK: %[[S7:.+]] = llvm.getelementptr %[[S6]][1, 0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<32 x array<64 x f32>>
+ // CHECK: %[[S8:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[S9:.+]] = llvm.insertvalue %[[S7]], %[[S8]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[S10:.+]] = llvm.insertvalue %[[S7]], %[[S9]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[S12:.+]] = llvm.insertvalue %[[S5]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[S13:.+]] = llvm.insertvalue %[[S4]], %[[S12]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[S14:.+]] = llvm.insertvalue %[[S4]], %[[S13]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[S15:.+]] = llvm.insertvalue %[[S3]], %[[S14]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[S16:.+]] = builtin.unrealized_conversion_cast %[[S15]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+
+ // CHECK: %[[S17:.+]] = llvm.getelementptr %[[S6]][1, 0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<32 x array<32 x f32>>
+ // CHECK: %[[S25:.+]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x32xf32, 3>
+
+ // CHECK: %[[S26:.+]] = llvm.getelementptr %[[S6]][4, 234] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<32 x array<32 x f32>>
+ // CHECK: %[[S34:.+]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x32xf32, #gpu.address_space<workgroup>>
+
+ // CHECK: %[[S35:.+]] = llvm.getelementptr %[[S6]][100, 4] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<32 x array<32 x f32>>
+ // CHECK: %[[S43:.+]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x32xf32, #gpu.address_space<workgroup>>
+
+ // CHECK: %[[S44:.+]] = llvm.getelementptr %[[S6]][100, %[[arg0]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, !llvm.array<128 x array<64 x f16>>
+ // CHECK: %[[S52:.+]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<128x64xf16, #gpu.address_space<workgroup>>
+
+ // CHECK: %[[S53:.+]] = llvm.getelementptr %[[S6]][32, 0, 0] : (!llvm.ptr<3>) -> !llvm.ptr<3>, !llvm.array<32 x array<8 x f32>>
+ // CHECK: %[[S61:.+]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x8xf32, #gpu.address_space<workgroup>>
+
+ // CHECK: "test.use.shared.memory"(%[[S16]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+ // CHECK: "test.use.shared.memory"(%[[S25]]) : (memref<32x32xf32, 3>) -> ()
+ // CHECK: "test.use.shared.memory"(%[[S34]]) : (memref<32x32xf32, #gpu.address_space<workgroup>>) -> ()
+ // CHECK: "test.use.shared.memory"(%[[S43]]) : (memref<32x32xf32, #gpu.address_space<workgroup>>) -> ()
+ // CHECK: "test.use.shared.memory"(%[[S52]]) : (memref<128x64xf16, #gpu.address_space<workgroup>>) -> ()
+ // CHECK: "test.use.shared.memory"(%[[S61]]) : (memref<32x8xf32, #gpu.address_space<workgroup>>) -> ()
+
+ gpu.return
+ }
+}
\ No newline at end of file
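For reviewers checking the expected values in the CHECK lines above: each `gpu.dynamic.shared.memory` result is expected to lower to a `getelementptr` on `@__shmem_dynamic_shared_memory_kernel_0`. That is followed by a memref descriptor whose allocated and aligned pointers are both the GEP result, with offset 0, sizes equal to the static shape, and row-major strides. The Python sketch below models those values for the static cases only. It assumes plain LLVM `getelementptr` semantics over nested arrays, where the first index steps over the whole `!llvm.array`. The helper names `gep_byte_offset` and `expected_descriptor` are hypothetical. This is an illustration of what the test expects, not the lowering code in the PR.

```python
from math import prod


def gep_byte_offset(indices, shape, elem_bytes):
    """Byte offset of `getelementptr base[indices]` over a nested LLVM array type.

    Index 0 steps over the whole array (e.g. !llvm.array<32 x array<64 x f32>>),
    index 1 over one row, and so on down to a single element.
    """
    if len(indices) > len(shape) + 1:
        raise ValueError("too many GEP indices for this array type")
    scales = [prod(shape[k:]) * elem_bytes for k in range(len(shape) + 1)]
    return sum(i * s for i, s in zip(indices, scales))


def expected_descriptor(indices, shape, elem_bytes):
    """Descriptor fields the CHECK lines expect (hypothetical model, static shapes)."""
    ptr = gep_byte_offset(indices, shape, elem_bytes)  # bytes past the global symbol
    return {
        "allocated_ptr": ptr,   # field [0]: the GEP result
        "aligned_ptr": ptr,     # field [1]: the same GEP result
        "offset": 0,            # field [2]
        "sizes": list(shape),   # fields [3, i]
        "strides": [prod(shape[i + 1:]) for i in range(len(shape))],  # fields [4, i]
    }


# %0 = gpu.dynamic.shared.memory [1, 0, 0] : memref<32x64xf32, ...>
print(expected_descriptor([1, 0, 0], (32, 64), 4))
# {'allocated_ptr': 8192, 'aligned_ptr': 8192, 'offset': 0, 'sizes': [32, 64], 'strides': [64, 1]}
```

Under the same assumptions, `%2` (`[4, 234]` on `32x32xf32`) points 46336 bytes past the symbol, and `%5` (`[32, 0, 0]` on `32x8xf32`) points 32768 bytes past it.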
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index c8c0b7d24bc3ab2..c35c286fa165999 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -640,3 +640,52 @@ module {
// expected-error @+1 {{'gpu.binary' op attribute 'offloadingHandler' failed to satisfy constraint: any attribute with the `OffloadingTranslationAttrTrait` trait.}}
gpu.binary @binary <1> [#gpu.object<#nvvm.target, "">]
}
+
+// -----
+
+func.func @main() {
+ %shmemSize = arith.constant 10000 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @+1 {{'gpu.dynamic.shared.memory' op gpu.launch allocates a 10000 bytes of dynamic shared memory, but the Op's access upper bound requires 8192000 bytes, which exceeds the currently allocated memory limit}}
+ %0 = gpu.dynamic.shared.memory [1000, 0, 0] : memref<32x64xf32, #gpu.address_space<workgroup>>
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @main() {
+ %shmemSize = arith.constant 8192 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @+1 {{'gpu.dynamic.shared.memory' op gpu.launch allocates a 8192 bytes of dynamic shared memory, but the Op's access upper bound requires 8196 bytes, which exceeds the currently allocated memory limit}}
+ %0 = gpu.dynamic.shared.memory [1, 0, 1] : memref<32x64xf32, #gpu.address_space<workgroup>>
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @main(%arg0 : index) {
+ %shmemSize = arith.constant 8192 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @+1 {{'gpu.dynamic.shared.memory' op gpu.launch allocates a 8192 bytes of dynamic shared memory, but the Op's access upper bound requires 262144 bytes, which exceeds the currently allocated memory limit}}
+ %0 = gpu.dynamic.shared.memory [%arg0, 0, 0] : memref<128x512xf32, #gpu.address_space<workgroup>>
+ gpu.terminator
+ }
+ return
+}
+
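The three `invalid.mlir` cases above exercise the compile-time bound check. For the two cases with static indices, the byte counts in the expected errors match the byte offset that a `getelementptr` over `!llvm.array<32 x array<64 x f32>>` would produce. The leading index scales by the whole memref (32 * 64 * 4 = 8192 bytes), the next by one row (256 bytes), and the last by one `f32` (4 bytes). The sketch below reproduces those two numbers. The helper name `access_bytes` and the `required > allocated` rejection rule are assumptions for illustration, not the verifier's code. The sketch does not model how the verifier treats the dynamic `%arg0` index in the third case. The 262144 bytes reported there happen to equal the size of one `memref<128x512xf32>` (128 * 512 * 4).

```python
from math import prod


def access_bytes(indices, shape, elem_bytes):
    """Byte offset reached by static indices, scaled GEP-style (hypothetical helper).

    The leading index scales by the size of the whole memref, the remaining
    indices by row-major strides, matching the lowering tests.
    """
    scales = [prod(shape[k:]) * elem_bytes for k in range(len(shape) + 1)]
    return sum(i * s for i, s in zip(indices, scales))


# (dynamic_shared_memory_size, indices) for the two static cases, memref<32x64xf32>
for allocated, indices in [(10000, [1000, 0, 0]), (8192, [1, 0, 1])]:
    required = access_bytes(indices, (32, 64), 4)
    # Assumed rejection rule: the access needs more bytes than were allocated.
    print(allocated, required, required > allocated)
# 10000 8192000 True
# 8192 8196 True
```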