[Mlir-commits] [mlir] ea84897 - [mlir][gpu] Introduce `gpu.dynamic_shared_memory` Op (#71546)
llvmlistbot at llvm.org
Thu Nov 16 05:42:22 PST 2023
Author: Guray Ozen
Date: 2023-11-16T14:42:17+01:00
New Revision: ea84897ba3e7727a3aa3fbd6d84b6b4ab573c70d
URL: https://github.com/llvm/llvm-project/commit/ea84897ba3e7727a3aa3fbd6d84b6b4ab573c70d
DIFF: https://github.com/llvm/llvm-project/commit/ea84897ba3e7727a3aa3fbd6d84b6b4ab573c70d.diff
LOG: [mlir][gpu] Introduce `gpu.dynamic_shared_memory` Op (#71546)
While the `gpu.launch` Op allows setting the size via the
`dynamic_shared_memory_size` argument, accessing the dynamic shared
memory is very convoluted. This PR implements the proposed Op,
`gpu.dynamic_shared_memory` that aims to simplify the utilization of
dynamic shared memory.
RFC:
https://discourse.llvm.org/t/rfc-simplifying-dynamic-shared-memory-access-in-gpu/
**Proposal from RFC**
This PR introduces the `gpu.dynamic_shared_memory` Op to make efficient use of
the dynamic shared memory feature. Dynamic shared memory is a powerful feature:
the host sizes and allocates shared memory at runtime when it launches the
kernel, and the device can then access that memory directly. I believe a
similar story exists for AMDGPU.
**Current Way of Using Dynamic Shared Memory with MLIR**
Let me illustrate the challenges of using dynamic shared memory in MLIR
with the example below. The process involves several steps:
- `memref.global`: declare the 0-sized array that LLVM's NVPTX backend expects
- `dynamic_shared_memory_size`: set the size of dynamic shared memory
- `memref.get_global`: access the global symbol
- `memref.reinterpret_cast` and `memref.subview`: many Ops for pointer arithmetic
```
// Step 1. Create 0-sized global symbol. Manually set the alignment
memref.global "private" @dynamicShmem : memref<0xf16, 3> { alignment = 16 }
func.func @main() {
// Step 2. Allocate shared memory
gpu.launch blocks(...) threads(...)
dynamic_shared_memory_size %c10000 {
// Step 3. Access the global object
%shmem = memref.get_global @dynamicShmem : memref<0xf16, 3>
// Step 4. A sequence of `memref.reinterpret_cast` and `memref.subview` operations.
%4 = memref.reinterpret_cast %shmem to offset: [0], sizes: [14, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
%5 = memref.subview %4[7, 0, 0][7, 64, 128][1,1,1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
%6 = memref.subview %5[2, 0, 0][1, 64, 128][1,1,1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<64x128xf16, strided<[128, 1], offset: 73728>, 3>
%7 = memref.subview %6[0, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>
%8 = memref.subview %6[32, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>
// Step 5. Use
"test.use.shared.memory"(%7) : (memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>) -> (index)
"test.use.shared.memory"(%8) : (memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>) -> (index)
gpu.terminator
}
}
```
Let’s rewrite the program above using the proposed `gpu.dynamic_shared_memory` Op:
```
func.func @main() {
gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
%i = arith.constant 18 : index
// Step 1: Obtain shared memory directly
%shmem = gpu.dynamic_shared_memory : memref<?xi8, 3>
%c147456 = arith.constant 147456 : index
%c155648 = arith.constant 155648 : index
%7 = memref.view %shmem[%c147456][] : memref<?xi8, 3> to memref<64x64xf16, 3>
%8 = memref.view %shmem[%c155648][] : memref<?xi8, 3> to memref<64x64xf16, 3>
// Step 2: Utilize the shared memory
"test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
"test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
gpu.terminator
}
}
```
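The byte offsets passed to `memref.view` in the rewritten program (147456 and 155648) correspond to the element offsets of `%7` and `%8` in the first example: each `memref.subview` adds the dot product of its start indices and the source strides, and the result is scaled by 2 bytes per `f16`. A minimal C++ sketch of that arithmetic, assuming the row-major strides shown in the first example; `subviewOffset` and `viewByteOffsets` are hypothetical helpers, not MLIR APIs:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Element offset added by a subview that starts at `starts` in a source
// memref with the given strides (all measured in elements).
template <std::size_t Rank>
int64_t subviewOffset(const std::array<int64_t, Rank> &starts,
                      const std::array<int64_t, Rank> &strides) {
  int64_t offset = 0;
  for (std::size_t i = 0; i < Rank; ++i) {
    assert(starts[i] >= 0 && "subview start must be non-negative");
    offset += starts[i] * strides[i];
  }
  return offset;
}

constexpr int64_t kF16Bytes = 2; // size of one f16 element in bytes

// Byte offsets of %7 and %8 from the start of dynamic shared memory.
inline std::array<int64_t, 2> viewByteOffsets() {
  const std::array<int64_t, 3> strides3d = {8192, 128, 1};
  const std::array<int64_t, 2> strides2d = {128, 1};
  int64_t off5 = subviewOffset<3>({7, 0, 0}, strides3d);        // %5: 57344
  int64_t off6 = off5 + subviewOffset<3>({2, 0, 0}, strides3d); // %6: 73728
  int64_t off7 = off6 + subviewOffset<2>({0, 0}, strides2d);    // %7: 73728
  int64_t off8 = off6 + subviewOffset<2>({32, 0}, strides2d);   // %8: 77824
  return {off7 * kF16Bytes, off8 * kF16Bytes};
}
```

Only the starting offsets carry over: the `memref.view` results are contiguous 64x64 buffers, whereas `%7` and `%8` in the first example keep the parent's row stride of 128. Also note that `%8` starts at row 32 of the 64-row `%6` but spans 64 rows; the sketch reproduces the offsets only and does no bounds checking.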
This PR resolves #72513
Added:
mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
Modified:
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
mlir/include/mlir/IR/SymbolTable.h
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/test/Dialect/GPU/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 755c82d8b75c9c0..ccb9580adbd1f54 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -52,6 +52,14 @@ def GPU_Dialect : Dialect {
/// Returns the numeric value used to identify the private memory address
/// space.
static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
+
+ /// Return true if the given MemRefType has an address space that matches
+ /// the gpu::AddressSpaceAttr attribute with value `workgroup`.
+ static bool hasWorkgroupMemoryAddressSpace(MemRefType type);
+
+ /// Return true if the given Attribute is a gpu::AddressSpaceAttr
+ /// attribute with value `workgroup`.
+ static bool isWorkgroupMemoryAddressSpace(Attribute memorySpace);
}];
let dependentDialects = ["arith::ArithDialect"];
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 632cdd96c6d4c2b..e11c5c393648de7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -433,6 +433,32 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
let hasVerifier = 1;
}
+def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [Pure]>
+{
+ let summary = "Get the memref for dynamic shared memory";
+
+ let description = [{
+ This operation provides a memref pointer to the start of dynamic shared
+ memory, often referred to as workgroup memory. It's important to note that
+ this dynamic shared memory needs to be allocated at kernel launch. One can
+ conveniently utilize the `dynamic_shared_memory_size` parameter of
+ `gpu.launch` for this purpose.
+
+ Examples:
+ ```mlir
+ %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %1 = memref.view %0[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>>
+ to memref<32x64xf32, #gpu.address_space<workgroup>>
+ %2 = memref.view %0[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>>
+ to memref<32x64xf32, #gpu.address_space<workgroup>>
+ ```
+ }];
+ let arguments = (ins);
+ let results = (outs Arg<MemRefRankOf<[I8], [1]>>:$resultMemref);
+ let assemblyFormat = [{ attr-dict `:` type($resultMemref) }];
+ let hasVerifier = 1;
+}
+
def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 8ff8f850a9c1858..08019e77ae6af8a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -27,6 +27,9 @@
namespace mlir {
namespace NVVM {
+// Shared memory has 128-bit alignment
+constexpr int kSharedMemoryAlignmentBit = 128;
+
/// NVVM memory space identifiers.
enum NVVMMemorySpace {
/// Global memory space identifier.
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 7f21f22eba951e3..597c6a9a1d89108 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -103,6 +103,24 @@ class SymbolTable {
Nested,
};
+ /// Generate a unique symbol name. Iteratively increase uniquingCounter
+ /// and use it as a suffix for symbol names until uniqueChecker does not
+ /// detect any conflict.
+ template <unsigned N, typename UniqueChecker>
+ static SmallString<N> generateSymbolName(StringRef name,
+ UniqueChecker uniqueChecker,
+ unsigned &uniquingCounter) {
+ SmallString<N> nameBuffer(name);
+ unsigned originalLength = nameBuffer.size();
+ do {
+ nameBuffer.resize(originalLength);
+ nameBuffer += '_';
+ nameBuffer += std::to_string(uniquingCounter++);
+ } while (uniqueChecker(nameBuffer));
+
+ return nameBuffer;
+ }
+
/// Returns the name of the given symbol operation, aborting if no symbol is
/// present.
static StringAttr getSymbolName(Operation *symbol);
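`generateSymbolName` always appends `_<counter>`, even when the base name itself is free; this is why the lowering's globals are named `__dynamic_shmem__0`, `__dynamic_shmem__1`, and so on (base name `__dynamic_shmem_`, then `_`, then the counter). A std-only C++ sketch of the same loop, using `std::string` in place of `SmallString<N>`; `generateUniqueName` is a hypothetical name, not an MLIR API:

```cpp
#include <cassert>
#include <string>

// Mirrors the loop in SymbolTable::generateSymbolName: append "_<counter>",
// bump the counter, and retry while `isTaken` reports a conflict.
template <typename IsTaken>
std::string generateUniqueName(const std::string &name, IsTaken isTaken,
                               unsigned &uniquingCounter) {
  std::string buffer = name;
  const std::string::size_type originalLength = buffer.size();
  do {
    buffer.resize(originalLength); // drop the suffix from the previous attempt
    buffer += '_';
    buffer += std::to_string(uniquingCounter++);
  } while (isTaken(buffer));
  return buffer;
}

// Usage: with "foo_0" and "foo_1" taken, the third attempt succeeds.
inline void exampleUsage() {
  unsigned counter = 0;
  std::string name = generateUniqueName(
      "foo",
      [](const std::string &candidate) {
        return candidate == "foo_0" || candidate == "foo_1";
      },
      counter);
  assert(name == "foo_2");
  assert(counter == 3); // the counter is bumped once per attempt
}
```

In `SymbolTable::insert`, the checker has a side effect: it tries to insert the candidate into the table and reports a conflict when that insertion fails, so the chosen name is registered as soon as it is found.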
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index a747e9742d4fb72..e79a02f931af2d5 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
@@ -549,6 +550,104 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
return IntegerAttr::get(IntegerType::get(ctx, 64), space);
}
+/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
+/// or reuses an existing suitable symbol.
+LLVM::GlobalOp
+getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
+ Operation *moduleOp, gpu::DynamicSharedMemoryOp op,
+ const LLVMTypeConverter *typeConverter,
+ MemRefType memrefType, unsigned alignmentBit) {
+ uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+
+ FailureOr<unsigned> addressSpace =
+ typeConverter->getMemRefAddressSpace(memrefType);
+ if (failed(addressSpace)) {
+ op->emitError() << "conversion of memref memory space "
+ << memrefType.getMemorySpace()
+ << " to integer address space "
+ "failed. Consider adding memory space conversions.";
+ }
+
+ // Step 1. Collect symbol names of LLVM::GlobalOp Ops. If any
+ // LLVM::GlobalOp is suitable for dynamic shared memory, return it.
+ llvm::StringSet<> existingGlobalNames;
+ for (auto globalOp :
+ moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
+ existingGlobalNames.insert(globalOp.getSymName());
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+ if (globalOp.getAddrSpace() == addressSpace.value() &&
+ arrayType.getNumElements() == 0 &&
+ globalOp.getAlignment().value_or(0) == alignmentByte) {
+ return globalOp;
+ }
+ }
+ }
+
+ // Step 2. Find a unique symbol name
+ unsigned uniquingCounter = 0;
+ SmallString<128> symName = SymbolTable::generateSymbolName<128>(
+ "__dynamic_shmem_",
+ [&](StringRef candidate) {
+ return existingGlobalNames.contains(candidate);
+ },
+ uniquingCounter);
+
+ // Step 3. Generate a global op
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(&moduleOp->getRegion(0).front().front());
+
+ auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
+ typeConverter->convertType(memrefType.getElementType()), 0);
+
+ return rewriter.create<LLVM::GlobalOp>(
+ op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+ LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
+ addressSpace.value());
+}
+
+LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
+ gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ MemRefType memrefType = op.getResultMemref().getType();
+ Type elementType = typeConverter->convertType(memrefType.getElementType());
+
+ // Step 1: Generate a memref<0xi8> type
+ MemRefLayoutAttrInterface layout = {};
+ auto memrefType0sz =
+ MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+
+ // Step 2: Generate a global symbol, or reuse an existing one, for the
+ // dynamic shared memory with memref<0xi8> type
+ LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+ LLVM::GlobalOp shmemOp = {};
+ Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
+ shmemOp = getDynamicSharedMemorySymbol(
+ rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
+
+ // Step 3. Get address of the global symbol
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(op);
+ auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+ Type baseType = basePtr->getResultTypes().front();
+
+ // Step 4. Generate GEP using offsets
+ SmallVector<LLVM::GEPArg> gepArgs = {0};
+ Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
+ basePtr, gepArgs);
+ // Step 5. Create a memref descriptor
+ SmallVector<Value> shape, strides;
+ Value sizeBytes;
+ getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
+ sizeBytes);
+ auto memRefDescriptor = this->createMemRefDescriptor(
+ loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
+
+ // Step 6. Replace the op with memref descriptor
+ rewriter.replaceOp(op, {memRefDescriptor});
+ return success();
+}
+
void mlir::populateGpuMemorySpaceAttributeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
typeConverter.addTypeAttributeConversion(
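Steps 1 and 2 of `getDynamicSharedMemorySymbol` either reuse a zero-length array global in the converted address space with the expected alignment, or pick a fresh name. A hypothetical std-only model of that decision, with `llvm.mlir.global` reduced to a small struct (`GlobalInfo`, `SymbolChoice`, and `chooseDynamicShmemSymbol` are illustrative names only). It takes the converted address space as an input; note that in the patch as shown, a failed address-space conversion emits an error but does not return, so `addressSpace.value()` is still evaluated afterwards.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for an `llvm.mlir.global` of array type.
struct GlobalInfo {
  std::string name;
  unsigned addrSpace = 0;            // 0 when no addr_space attribute is set
  uint64_t numElements = 0;          // length of the !llvm.array
  std::optional<uint64_t> alignment; // alignment attribute, if present
};

struct SymbolChoice {
  bool reused;      // true if an existing global was selected
  std::string name; // symbol whose address the lowering takes
};

// Reuse a zero-length array global in `addrSpace` whose alignment matches,
// otherwise pick a fresh "__dynamic_shmem__<N>" name.
inline SymbolChoice chooseDynamicShmemSymbol(
    const std::vector<GlobalInfo> &globals, unsigned addrSpace,
    unsigned alignmentBit, unsigned elementBitWidth) {
  assert(elementBitWidth != 0 && "element type must have a bit width");
  // Same formula as the patch; for i8 elements this is 128 / 8 = 16.
  const uint64_t alignmentByte = alignmentBit / elementBitWidth;

  std::set<std::string> existingNames;
  for (const GlobalInfo &global : globals) {
    existingNames.insert(global.name);
    if (global.addrSpace == addrSpace && global.numElements == 0 &&
        global.alignment.value_or(0) == alignmentByte)
      return {true, global.name};
  }

  // Base name "__dynamic_shmem_" plus the "_<counter>" suffix.
  unsigned counter = 0;
  std::string candidate;
  do {
    candidate = "__dynamic_shmem__" + std::to_string(counter++);
  } while (existingNames.count(candidate) != 0);
  return {false, candidate};
}
```

Fed the three pre-existing globals from `dynamic-shared-memory.mlir` (two in address space 3 with alignment 4, one with alignment 16 but no address space), none qualifies, so the model yields `__dynamic_shmem__3`, the name the test's CHECK lines expect.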
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index bd90286494d8035..a77db4a036bad3f 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
namespace mlir {
+/// Lowering for gpu.dynamic_shared_memory to LLVM dialect. The pattern first
+/// creates a 0-sized global array symbol, as LLVM expects. It then constructs
+/// a memref descriptor from the symbol's address and returns it.
+struct GPUDynamicSharedMemoryOpLowering
+ : public ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp> {
+ using ConvertOpToLLVMPattern<
+ gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
+ GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
+ unsigned alignmentBit = 0)
+ : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
+ alignmentBit(alignmentBit) {}
+
+ LogicalResult
+ matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+
+private:
+ // Alignment of the dynamic shared memory symbol, in bits
+ unsigned alignmentBit;
+};
+
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
GPUFuncOpLowering(const LLVMTypeConverter &converter,
unsigned allocaAddrSpace, unsigned workgroupAddrSpace,
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 935e3d2a4095003..86a77f557cb9579 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -325,6 +325,9 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
converter);
+ patterns.add<GPUDynamicSharedMemoryOpLowering>(
+ converter, NVVM::kSharedMemoryAlignmentBit);
+
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
// memory space and does not support `alloca`s with addrspace(5).
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index e0a2b93df3d1fd6..9517c053c8360ef 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -23,6 +23,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -164,17 +165,18 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
// GPUDialect
//===----------------------------------------------------------------------===//
-/// GPU memory space identifiers.
-enum GPUMemorySpace {
- /// Generic memory space identifier.
- kGenericMemorySpace = 0,
-
- /// Global memory space identifier.
- kGlobalMemorySpace = 1,
+bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
+ if (!memorySpace)
+ return false;
+ if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuAttr.getValue() == getWorkgroupAddressSpace();
+ return false;
+}
- /// Shared memory space identifier.
- kSharedMemorySpace = 3
-};
+bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
+ Attribute memorySpace = type.getMemorySpace();
+ return isWorkgroupMemoryAddressSpace(memorySpace);
+}
bool GPUDialect::isKernel(Operation *op) {
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
@@ -2047,6 +2049,28 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// DynamicSharedMemoryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::DynamicSharedMemoryOp::verify() {
+ if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
+ return emitOpError() << "must be inside an op with symbol table";
+
+ MemRefType memrefType = getResultMemref().getType();
+ // Check address space
+ if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
+ return emitOpError() << "address space must be "
+ << gpu::AddressSpaceAttr::getMnemonic() << "<"
+ << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
+ }
+ if (memrefType.hasStaticShape()) {
+ return emitOpError() << "result memref type must be memref<?xi8, "
+ "#gpu.address_space<workgroup>>";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// GPU target options
//===----------------------------------------------------------------------===//
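Taken together, the ODS result constraint (`MemRefRankOf<[I8], [1]>`) and `DynamicSharedMemoryOp::verify()` accept only `memref<?xi8, #gpu.address_space<workgroup>>` inside an op with a symbol table. A hypothetical std-only model of those checks, with the memref type reduced to a small struct (`MemRefDesc` and `verifyDynamicSharedMemory` are illustrative names). The ODS constraint is checked first, since ODS-generated constraints are verified before a custom `verify()` runs:

```cpp
#include <cassert>
#include <string>
#include <vector>

constexpr long long kDynamic = -1; // stands in for a '?' dimension

// Hypothetical, reduced description of the op's result memref type.
struct MemRefDesc {
  std::vector<long long> shape;
  unsigned elementBitWidth;
  bool elementIsSignlessInt;
  std::string memorySpace; // "workgroup", "private", or "" when absent
};

// Returns "" when the type is accepted, otherwise a message modeled on the
// diagnostics in the patch.
inline std::string verifyDynamicSharedMemory(const MemRefDesc &type,
                                             bool hasSymbolTableParent) {
  // ODS constraint: MemRefRankOf<[I8], [1]>.
  if (type.shape.size() != 1 || !type.elementIsSignlessInt ||
      type.elementBitWidth != 8)
    return "result #0 must be 1D memref of 8-bit signless integer values";
  // Custom verifier, in the order used by DynamicSharedMemoryOp::verify().
  if (!hasSymbolTableParent)
    return "must be inside an op with symbol table";
  if (type.memorySpace != "workgroup")
    return "address space must be address_space<workgroup>";
  if (type.shape[0] != kDynamic)
    return "result memref type must be memref<?xi8, "
           "#gpu.address_space<workgroup>>";
  return "";
}

// Usage: the four rejected cases exercised in invalid.mlir.
inline void checkInvalidMlirCases() {
  assert(verifyDynamicSharedMemory({{kDynamic}, 8, true, ""}, true) ==
         "address space must be address_space<workgroup>");
  assert(verifyDynamicSharedMemory({{1}, 8, true, "workgroup"}, true)
             .find("result memref type") == 0);
  assert(verifyDynamicSharedMemory({{kDynamic}, 8, true, "private"}, true) ==
         "address space must be address_space<workgroup>");
  assert(verifyDynamicSharedMemory({{kDynamic}, 32, false, "workgroup"}, true)
             .find("result #0") == 0);
}
```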
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 7180ea432ea057d..e83d19553d62ce8 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -200,20 +200,16 @@ StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
// If the symbol was already in the table, also return.
if (symbolTable.lookup(name) == symbol)
return name;
- // If a conflict was detected, then the symbol will not have been added to
- // the symbol table. Try suffixes until we get to a unique name that works.
- SmallString<128> nameBuffer(name.getValue());
- unsigned originalLength = nameBuffer.size();
MLIRContext *context = symbol->getContext();
-
- // Iteratively try suffixes until we find one that isn't used.
- do {
- nameBuffer.resize(originalLength);
- nameBuffer += '_';
- nameBuffer += std::to_string(uniquingCounter++);
- } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
- .second);
+ SmallString<128> nameBuffer = generateSymbolName<128>(
+ name.getValue(),
+ [&](StringRef candidate) {
+ return !symbolTable
+ .insert({StringAttr::get(context, candidate), symbol})
+ .second;
+ },
+ uniquingCounter);
setSymbolName(symbol, nameBuffer);
return getSymbolName(symbol);
}
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
new file mode 100644
index 000000000000000..fb45faaa712f7a9
--- /dev/null
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
+
+gpu.module @modules {
+ // CHECK: llvm.mlir.global internal @__dynamic_shmem__3() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global internal @__dynamic_shmem__0() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global internal @__dynamic_shmem__1() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global internal @__dynamic_shmem__2() {alignment = 16 : i64} : !llvm.array<0 x i8>
+ // CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
+ // CHECK-SAME: %[[arg0:.+]]: i64)
+ gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {
+ %c1 = arith.constant 1 : index
+ %c8192 = arith.constant 8192 : index
+ %c16384 = arith.constant 16384 : index
+ %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %shmem2 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+
+ %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+ %1 = memref.view %shmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ "test.use.shared.memory"(%1) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+// CHECK: %[[S15:.+]] = llvm.getelementptr %[[S4]][16384] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S16:.+]] = llvm.insertvalue %[[S15]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S17:.+]] = llvm.insertvalue %[[S3]], %[[S16]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S18:.+]] = llvm.insertvalue %[[S1]], %[[S17]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S19:.+]] = llvm.insertvalue %[[S2]], %[[S18]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S20:.+]] = llvm.insertvalue %[[S0]], %[[S19]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S21:.+]] = llvm.insertvalue %[[S1]], %[[S20]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S22:.+]] = builtin.unrealized_conversion_cast %[[S21]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S22]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+ gpu.return
+ }
+
+// CHECK-LABEL: llvm.func @gpu_device_function
+ gpu.func @gpu_device_function() {
+ %c8192 = arith.constant 8192 : index
+ %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+ gpu.return
+ }
+
+// CHECK-LABEL: llvm.func @func_device_function
+ func.func @func_device_function() {
+ %c8192 = arith.constant 8192 : index
+ %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+ func.return
+ }
+}
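For each `memref.view`, the CHECK lines above assemble a ranked memref descriptor `{allocated, aligned, offset, sizes, strides}`: the allocated pointer is the address of `@__dynamic_shmem__3`, the aligned pointer is that address advanced by the byte shift (8192 or 16384) via `llvm.getelementptr ... i8`, the offset is 0, the sizes are [32, 64], and the strides are [64, 1]. A std-only C++ model of that descriptor; pointers are represented as plain byte addresses (an assumption so the sketch never touches real memory), and `viewAs2D` / `elementAddress` are hypothetical helpers:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Ranked memref descriptor fields in lowering order:
// {allocated, aligned, offset, sizes[Rank], strides[Rank]}.
template <std::size_t Rank>
struct MemRefDescriptor {
  int64_t allocated; // byte address of the allocation (the global symbol)
  int64_t aligned;   // byte address that indexing starts from
  int64_t offset;    // in elements, relative to `aligned`
  std::array<int64_t, Rank> sizes;
  std::array<int64_t, Rank> strides; // in elements
};

// Models `memref.view %shmem[%byteShift][]` to a contiguous row-major
// rows x cols memref: shift the aligned pointer and keep offset 0.
inline MemRefDescriptor<2> viewAs2D(int64_t shmemBase, int64_t byteShift,
                                    int64_t rows, int64_t cols) {
  assert(byteShift >= 0 && rows > 0 && cols > 0);
  return {shmemBase, shmemBase + byteShift, 0, {rows, cols}, {cols, 1}};
}

// Byte address of element `idx`, given the element size in bytes.
template <std::size_t Rank>
int64_t elementAddress(const MemRefDescriptor<Rank> &d,
                       const std::array<int64_t, Rank> &idx,
                       int64_t elementBytes) {
  int64_t linear = d.offset;
  for (std::size_t k = 0; k < Rank; ++k)
    linear += idx[k] * d.strides[k];
  return d.aligned + linear * elementBytes;
}
```

With 4-byte `f32` elements, the last element of the view at byte 8192 ends exactly where the view at byte 16384 begins, so the two 32x64 views in the test do not overlap.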
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 680e604151d77fd..df9921ef14d3b51 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -640,3 +640,69 @@ module {
// expected-error @+1 {{'gpu.binary' op attribute 'offloadingHandler' failed to satisfy constraint: any attribute with the `OffloadingTranslationAttrTrait` trait.}}
gpu.binary @binary <1> [#gpu.object<#nvvm.target, "">]
}
+
+// -----
+
+func.func @main() {
+ %shmemSize = arith.constant 10000 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
+ %0 = gpu.dynamic_shared_memory : memref<?xi8>
+ gpu.terminator
+ }
+ return
+}
+
+
+// -----
+
+func.func @main() {
+ %shmemSize = arith.constant 8192 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @below {{'gpu.dynamic_shared_memory' op result memref type must be memref<?xi8, #gpu.address_space<workgroup>>}}
+ %0 = gpu.dynamic_shared_memory : memref<1xi8, #gpu.address_space<workgroup>>
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @main(%arg0 : index) {
+ %shmemSize = arith.constant 8192 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
+ %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<private>>
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @main(%arg0 : index) {
+ %shmemSize = arith.constant 8192 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @below {{'gpu.dynamic_shared_memory' op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, #gpu.address_space<workgroup>}}
+ %0 = gpu.dynamic_shared_memory : memref<?xf32, #gpu.address_space<workgroup>>
+ gpu.terminator
+ }
+ return
+}
+