[Mlir-commits] [mlir] [mlir][gpu] Introduce `gpu.dynamic_shared_memory` Op (PR #71546)
Guray Ozen
llvmlistbot at llvm.org
Thu Nov 16 05:41:42 PST 2023
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/71546
>From 82c405940bcab96c083d2ef7817e70865e7c0e5b Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 16:14:59 +0100
Subject: [PATCH 01/19] [mlir][gpu] Introduce `gpu.dynamic_shared_memory` Op
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
While the `gpu.launch` Op allows setting the size via the `dynamic_shared_memory_size` argument, accessing the dynamic shared memory is very convoluted. This PR implements the proposed Op, `gpu.dynamic_shared_memory`, which aims to simplify the use of dynamic shared memory.
RFC: https://discourse.llvm.org/t/rfc-simplifying-dynamic-shared-memory-access-in-gpu/
**Proposal from RFC**
This PR introduces the `gpu.dynamic_shared_memory` Op so that the dynamic shared memory feature can be used efficiently. Dynamic shared memory is a powerful feature: the host sets the amount of shared memory at runtime when it launches the kernel, and the device can then access that memory directly. I believe a similar story exists for AMDGPU.
**Current Way of Using Dynamic Shared Memory with MLIR**
Let me illustrate the challenges of using dynamic shared memory in MLIR with the example below. The process involves several steps:
- `memref.global`: declare the 0-sized array that LLVM's NVPTX backend expects
- `dynamic_shared_memory_size`: set the size of dynamic shared memory
- `memref.get_global`: access the global symbol
- `memref.reinterpret_cast` and `memref.subview`: many Ops just for pointer arithmetic
```
// Step 1. Create 0-sized global symbol. Manually set the alignment
memref.global "private" @dynamicShmem : memref<0xf16, 3> { alignment = 16 }
func.func @main() {
// Step 2. Allocate shared memory
gpu.launch blocks(...) threads(...)
dynamic_shared_memory_size %c10000 {
// Step 3. Access the global object
%shmem = memref.get_global @dynamicShmem : memref<0xf16, 3>
// Step 4. A sequence of `memref.reinterpret_cast` and `memref.subview` operations.
%4 = memref.reinterpret_cast %shmem to offset: [0], sizes: [14, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
%5 = memref.subview %4[7, 0, 0][7, 64, 128][1,1,1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
%6 = memref.subview %5[2, 0, 0][1, 64, 128][1,1,1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<64x128xf16, strided<[128, 1], offset: 73728>, 3>
%7 = memref.subview %6[0, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>
%8 = memref.subview %6[32, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>
// Step 5. Use
"test.use.shared.memory"(%7) : (memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>) -> (index)
"test.use.shared.memory"(%8) : (memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>) -> (index)
gpu.terminator
}
return
}
```
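The hard-coded offsets in the strided layouts above (57344, 73728, 77824) follow the usual rule for a subview of a strided memref: the new offset is the source offset plus the sum of each subview index times the matching source stride. The C++ sketch below recomputes them at compile time to show how much bookkeeping Step 4 hides; `subviewOffset` is a hypothetical helper for illustration only, not an MLIR API.

```cpp
#include <cstdint>
#include <initializer_list>

// Hypothetical helper (not part of MLIR): offset, in elements, of a subview
// taken from a strided memref.
//   newOffset = srcOffset + sum_i(indices[i] * strides[i])
// Assumes `indices` and `strides` have the same rank.
constexpr int64_t subviewOffset(int64_t srcOffset,
                                std::initializer_list<int64_t> indices,
                                std::initializer_list<int64_t> strides) {
  int64_t offset = srcOffset;
  auto idx = indices.begin();
  auto str = strides.begin();
  // Walk both lists in lockstep.
  for (; idx != indices.end() && str != strides.end(); ++idx, ++str)
    offset += *idx * *str;
  return offset;
}

// %4 is memref<14x64x128xf16> with strides [8192, 128, 1] and offset 0.
constexpr int64_t kOff5 = subviewOffset(0, {7, 0, 0}, {8192, 128, 1});
constexpr int64_t kOff6 = subviewOffset(kOff5, {2, 0, 0}, {8192, 128, 1});
// %6 is 2-D with strides [128, 1].
constexpr int64_t kOff7 = subviewOffset(kOff6, {0, 0}, {128, 1});
constexpr int64_t kOff8 = subviewOffset(kOff6, {32, 0}, {128, 1});

static_assert(kOff5 == 57344, "offset of %5");
static_assert(kOff6 == 73728, "offset of %6");
static_assert(kOff7 == 73728, "offset of %7");
static_assert(kOff8 == 77824, "offset of %8");
```

Every one of these constants has to be spelled out by hand in the `memref.subview` result types, which is a large part of why this approach is error-prone.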
With the proposed `gpu.dynamic_shared_memory` Op, the same MLIR program becomes:
```
func.func @main() {
gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
// Step 1: Obtain shared memory directly
%shmem = gpu.dynamic_shared_memory : memref<?xi8, 3>
%c147456 = arith.constant 147456 : index
%c155648 = arith.constant 155648 : index
%7 = memref.view %shmem[%c147456][] : memref<?xi8, 3> to memref<64x64xf16, 3>
%8 = memref.view %shmem[%c155648][] : memref<?xi8, 3> to memref<64x64xf16, 3>
// Step 2: Utilize the shared memory
"test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
"test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
gpu.terminator
}
return
}
```
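The shifts passed to `memref.view` are byte offsets into the `i8` buffer, so they are the element offsets of `%7` and `%8` from the subview version (73728 and 77824) multiplied by the size of `f16` (2 bytes). The C++ sketch below checks that arithmetic and also computes the minimum number of bytes the launch must reserve so that the second 64x64xf16 tile stays in bounds. The helper names are illustrative only.

```cpp
#include <cstdint>

// Size in bytes of one f16 element.
constexpr int64_t kF16Bytes = 2;

// memref.view takes a byte shift into the i8 buffer, so an offset counted in
// elements must be scaled by the element size.
constexpr int64_t byteShift(int64_t elementOffset, int64_t elementBytes) {
  return elementOffset * elementBytes;
}

// Bytes occupied by a dense rows x cols tile.
constexpr int64_t tileBytes(int64_t rows, int64_t cols, int64_t elementBytes) {
  return rows * cols * elementBytes;
}

// Element offsets of %7 and %8 in the memref.subview version.
constexpr int64_t kShift7 = byteShift(73728, kF16Bytes);
constexpr int64_t kShift8 = byteShift(77824, kF16Bytes);

// Smallest dynamic_shared_memory_size (in bytes) for which the 64x64xf16 view
// starting at kShift8 is fully in bounds.
constexpr int64_t kMinShmemBytes = kShift8 + tileBytes(64, 64, kF16Bytes);

static_assert(kShift7 == 147456, "byte shift of %7");
static_assert(kShift8 == 155648, "byte shift of %8");
static_assert(kMinShmemBytes == 163840, "required dynamic shared memory");
```

By this arithmetic, the `%c10000` used in both examples appears to be a placeholder: a launch that actually uses these offsets would need a `dynamic_shared_memory_size` of at least 163840 bytes.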
---
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td | 10 +++
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 13 +++
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 23 +++++
.../include/mlir/Dialect/LLVMIR/NVVMDialect.h | 3 +
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 87 +++++++++++++++++++
.../lib/Conversion/GPUCommon/GPUOpsLowering.h | 21 +++++
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 3 +
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 41 ++++++---
.../Dialect/GPU/dynamic-shared-memory.mlir | 35 ++++++++
mlir/test/Dialect/GPU/invalid.mlir | 49 +++++++++++
10 files changed, 275 insertions(+), 10 deletions(-)
create mode 100644 mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 755c82d8b75c9c0..057b507c394e80f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -52,6 +52,16 @@ def GPU_Dialect : Dialect {
/// Returns the numeric value used to identify the private memory address
/// space.
static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
+
+ /// Return true if the given MemRefType has an integer address
+ /// space that matches the workgroup memory address space or
+ /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+ static bool hasWorkgroupMemoryAddressSpace(MemRefType type);
+
+ /// Return true if the given Attribute has an integer address
+ /// space that matches the workgroup memory address space or
+ /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+ static bool isWorkgroupMemoryAddressSpace(Attribute memorySpace);
}];
let dependentDialects = ["arith::ArithDialect"];
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 14a1fac5fd255f3..286856324950eb7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -17,6 +17,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/DLTI/Traits.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -32,6 +33,18 @@
namespace mlir {
namespace gpu {
+/// GPU memory space identifiers.
+enum GPUMemorySpace {
+ /// Generic memory space identifier.
+ kGenericMemorySpace = 0,
+
+ /// Global memory space identifier.
+ kGlobalMemorySpace = 1,
+
+ /// Shared memory space identifier.
+ kSharedMemorySpace = 3
+};
+
/// Utility class for the GPU dialect to represent triples of `Value`s
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
struct KernelDim3 {
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6375d35f4311295..f3a37c62d3a7465 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -433,6 +433,29 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
let hasVerifier = 1;
}
+def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [] > {
+ let summary = "Get the memref for dynamic shared memory";
+
+ let description = [{
+ This operation provides a memref pointer to the start of dynamic shared
+ memory, often referred to as workgroup memory. It's important to note that
+ this dynamic shared memory needs to be allocated at kernel launch. One can
+ conveniently utilize `the dynamic_shared_memory_size` parameter of
+ `gpu.launch` for this purpose.
+
+ Examples:
+ ```mlir
+ %0 = gpu.dynamic.shared.memory : memref<?xi8, 3>
+ %1 = memref.view %0[%c8192][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ %2 = memref.view %0[%c16384][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ ```
+ }];
+ let arguments = (ins );
+ let results = (outs Arg<MemRefRankOf<[I8], [1]>>:$resultMemref);
+ let assemblyFormat = [{ attr-dict `:` type($resultMemref) }];
+ let hasVerifier = 1;
+}
+
def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 8ff8f850a9c1858..08019e77ae6af8a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -27,6 +27,9 @@
namespace mlir {
namespace NVVM {
+// Shared memory has 128-bit alignment
+constexpr int kSharedMemoryAlignmentBit = 128;
+
/// NVVM memory space identifiers.
enum NVVMMemorySpace {
/// Global memory space identifier.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6d2585aa30ab4c5..fbea498ee27caa7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
#include "GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -554,6 +555,92 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
return IntegerAttr::get(IntegerType::get(ctx, 64), space);
}
+/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
+/// or uses existing symbol.
+LLVM::GlobalOp getDynamicSharedMemorySymbol(
+ ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
+ const LLVMTypeConverter *typeConverter, MemRefType memrefType, unsigned alignmentBit) {
+ std::optional<LLVM::GlobalOp> existingGlobalOp;
+
+ LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+ assert(funcOp && "cannot find llvm.func op");
+
+ gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
+ assert(moduleOp && "cannot find gpu.module op");
+
+ // Use already generated global op if it exists
+ int index = 0;
+ std::string prefix = llvm::formatv("__shmem_{0}", funcOp.getSymName());
+ moduleOp->walk([&](LLVM::GlobalOp globalOp) {
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+ if (arrayType.getNumElements() == 0) {
+ existingGlobalOp = globalOp;
+ return WalkResult::interrupt();
+ }
+ }
+ if (globalOp.getSymName().startswith(prefix))
+ index++;
+ return WalkResult::advance();
+ });
+ if (existingGlobalOp.has_value())
+ return existingGlobalOp.value();
+
+ // Generate a new global op
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(&moduleOp.front());
+
+ auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
+ typeConverter->convertType(memrefType.getElementType()), 0);
+ std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
+ // TODO: better alignment calculation
+ uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+ return rewriter.create<LLVM::GlobalOp>(
+ funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+ LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignmentByte,
+ mlir::gpu::GPUMemorySpace::kSharedMemorySpace);
+}
+
+LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
+ gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ Location loc = op.getLoc();
+ MemRefType memrefType = op.getResultMemref().getType();
+ auto elementType = typeConverter->convertType(memrefType.getElementType());
+ assert(memrefType && "memref is not valid");
+
+ // Step 1: Generate a memref<0xi8> type
+ MemRefLayoutAttrInterface layout = {};
+ auto memrefType0sz = MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+
+ // Step 2: Generate a global symbol or existing for the dynamic shared
+ // memory with memref<0xi8> type
+ LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
+ rewriter, op, getTypeConverter(), memrefType0sz ,alignmentBit);
+ assert(shmemOp && "cannot find module op or failed generating global op");
+
+ // Step 3. Get address of the global symbol
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(op);
+ auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+ Type baseType = basePtr->getResultTypes().front();
+
+ // Step 4. Generate GEP using offsets
+ SmallVector<LLVM::GEPArg> gepArgs = {0};
+ Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
+ basePtr, gepArgs);
+ // Step 5. Create a memref descriptor
+ SmallVector<Value> shape, strides;
+ Value sizeBytes;
+ getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
+ sizeBytes);
+ auto memRefDescriptor = this->createMemRefDescriptor(
+ loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
+
+ // Step 5. Replace the op with memref descriptor
+ rewriter.replaceOp(op, {memRefDescriptor});
+ return success();
+}
+
void mlir::populateGpuMemorySpaceAttributeConversions(
TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
typeConverter.addTypeAttributeConversion(
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index bd90286494d8035..a77db4a036bad3f 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
namespace mlir {
+/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
+/// create a 0-sized global array symbol similar as LLVM expects. It constructs
+/// a memref descriptor with these values and return it.
+struct GPUDynamicSharedMemoryOpLowering
+ : public ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp> {
+ using ConvertOpToLLVMPattern<
+ gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
+ GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
+ unsigned alignmentBit = 0)
+ : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
+ alignmentBit(alignmentBit) {}
+
+ LogicalResult
+ matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+
+private:
+ // Alignment bit
+ unsigned alignmentBit;
+};
+
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
GPUFuncOpLowering(const LLVMTypeConverter &converter,
unsigned allocaAddrSpace, unsigned workgroupAddrSpace,
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 935e3d2a4095003..86a77f557cb9579 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -325,6 +325,9 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
converter);
+ patterns.add<GPUDynamicSharedMemoryOpLowering>(
+ converter, NVVM::kSharedMemoryAlignmentBit);
+
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
// memory space and does not support `alloca`s with addrspace(5).
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5eb2cadc884e151..3216e82147da907 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -164,17 +164,20 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
// GPUDialect
//===----------------------------------------------------------------------===//
-/// GPU memory space identifiers.
-enum GPUMemorySpace {
- /// Generic memory space identifier.
- kGenericMemorySpace = 0,
-
- /// Global memory space identifier.
- kGlobalMemorySpace = 1,
+bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
+ if (!memorySpace)
+ return false;
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+ return intAttr.getInt() == GPUMemorySpace::kSharedMemorySpace;
+ if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+ return gpuAttr.getValue() == getWorkgroupAddressSpace();
+ return false;
+}
- /// Shared memory space identifier.
- kSharedMemorySpace = 3
-};
+bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
+ Attribute memorySpace = type.getMemorySpace();
+ return isWorkgroupMemoryAddressSpace(memorySpace);
+}
bool GPUDialect::isKernel(Operation *op) {
UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
@@ -2024,6 +2027,24 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+//===----------------------------------------------------------------------===//
+// DynamicSharedMemoryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::DynamicSharedMemoryOp::verify() {
+ MemRefType memrefType = getResultMemref().getType();
+ // Check address space
+ if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
+ return emitOpError() << "Address space must be "
+ << gpu::AddressSpaceAttr::getMnemonic() << "<"
+ << stringifyEnum(gpu::AddressSpace::Workgroup)
+ << "> or " << int(GPUMemorySpace::kSharedMemorySpace);
+ }
+ if(memrefType.hasStaticShape())
+ return emitOpError() << "result memref type must be memref<?xi8>";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// GPU target options
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
new file mode 100644
index 000000000000000..d3ca12b7131c469
--- /dev/null
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
+
+gpu.module @modules {
+ // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+
+ // CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
+ // CHECK-SAME: %[[arg0:.+]]: i64)
+ gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100 : index
+ %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+
+ %0 = memref.view %shmem[%c100][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_0 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][100] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+ gpu.return
+ }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index c8c0b7d24bc3ab2..768289819cc0e09 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -640,3 +640,52 @@ module {
// expected-error @+1 {{'gpu.binary' op attribute 'offloadingHandler' failed to satisfy constraint: any attribute with the `OffloadingTranslationAttrTrait` trait.}}
gpu.binary @binary <1> [#gpu.object<#nvvm.target, "">]
}
+
+// -----
+
+func.func @main() {
+ %shmemSize = arith.constant 10000 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @+1 {{op Address space must be address_space<workgroup> or 3}}
+ %0 = gpu.dynamic_shared_memory : memref<?xi8>
+ gpu.terminator
+ }
+ return
+}
+
+
+// -----
+
+func.func @main() {
+ %shmemSize = arith.constant 8192 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @+1 {{result memref type must be memref<?xi8>}}
+ %0 = gpu.dynamic_shared_memory : memref<1xi8,3>
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @main(%arg0 : index) {
+ %shmemSize = arith.constant 8192 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @+1 {{op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, 3>}}
+ %0 = gpu.dynamic_shared_memory : memref<?xf32,3>
+ gpu.terminator
+ }
+ return
+}
>From f7babc582ea39a01e17617790652f76939757b39 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 16:16:12 +0100
Subject: [PATCH 02/19] add nl
---
mlir/test/Dialect/GPU/dynamic-shared-memory.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index d3ca12b7131c469..a2706fa6bde7ae9 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -32,4 +32,4 @@ gpu.module @modules {
gpu.return
}
-}
\ No newline at end of file
+}
>From 8924447541382bacd05da1b9a207e9bb196409f6 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 16:18:39 +0100
Subject: [PATCH 03/19] remove todo
---
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index fbea498ee27caa7..aaee407166f581a 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -592,7 +592,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
typeConverter->convertType(memrefType.getElementType()), 0);
std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
- // TODO: better alignment calculation
+
uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
return rewriter.create<LLVM::GlobalOp>(
funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
>From badeaa3e0c9375699eaa104310a3ba1fafbffe1d Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 16:21:38 +0100
Subject: [PATCH 04/19] improve the test
---
.../Dialect/GPU/dynamic-shared-memory.mlir | 21 +++++++++++++++----
1 file changed, 17 insertions(+), 4 deletions(-)
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index a2706fa6bde7ae9..04bacb16140ad97 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -7,11 +7,16 @@ gpu.module @modules {
// CHECK-SAME: %[[arg0:.+]]: i64)
gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {
%c1 = arith.constant 1 : index
- %c100 = arith.constant 100 : index
+ %c8192 = arith.constant 8192 : index
+ %c16384 = arith.constant 16384 : index
%shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %shmem2 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
- %0 = memref.view %shmem[%c100][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
"test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+ %1 = memref.view %shmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ "test.use.shared.memory"(%1) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
@@ -20,7 +25,7 @@ gpu.module @modules {
// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_0 : !llvm.ptr<3>
// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][100] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
@@ -29,7 +34,15 @@ gpu.module @modules {
// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
-
+// CHECK: %[[S15:.+]] = llvm.getelementptr %4[16384] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S16:.+]] = llvm.insertvalue %[[S15]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S17:.+]] = llvm.insertvalue %[[S3]], %[[S16]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S18:.+]] = llvm.insertvalue %[[S1]], %[[S17]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S19:.+]] = llvm.insertvalue %[[S2]], %[[S18]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S20:.+]] = llvm.insertvalue %[[S0]], %[[S19]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S21:.+]] = llvm.insertvalue %[[S1]], %[[S20]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S22:.+]] = builtin.unrealized_conversion_cast %[[S21]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S22]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
gpu.return
}
}
>From 91221dfb84ebad8326a30c4c8a369a6af9bdef1b Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 17:01:36 +0100
Subject: [PATCH 05/19] fix format
---
.../lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 17 ++++++++++-------
1 file changed, 10 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index aaee407166f581a..4cc9e1c353189c6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -557,9 +557,11 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
/// or uses existing symbol.
-LLVM::GlobalOp getDynamicSharedMemorySymbol(
- ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
- const LLVMTypeConverter *typeConverter, MemRefType memrefType, unsigned alignmentBit) {
+LLVM::GlobalOp
+getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
+ gpu::DynamicSharedMemoryOp op,
+ const LLVMTypeConverter *typeConverter,
+ MemRefType memrefType, unsigned alignmentBit) {
std::optional<LLVM::GlobalOp> existingGlobalOp;
LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -592,7 +594,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
typeConverter->convertType(memrefType.getElementType()), 0);
std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
-
+
uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
return rewriter.create<LLVM::GlobalOp>(
funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
@@ -607,15 +609,16 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
MemRefType memrefType = op.getResultMemref().getType();
auto elementType = typeConverter->convertType(memrefType.getElementType());
assert(memrefType && "memref is not valid");
-
+
// Step 1: Generate a memref<0xi8> type
MemRefLayoutAttrInterface layout = {};
- auto memrefType0sz = MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+ auto memrefType0sz =
+ MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
// Step 2: Generate a global symbol or existing for the dynamic shared
// memory with memref<0xi8> type
LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
- rewriter, op, getTypeConverter(), memrefType0sz ,alignmentBit);
+ rewriter, op, getTypeConverter(), memrefType0sz, alignmentBit);
assert(shmemOp && "cannot find module op or failed generating global op");
// Step 3. Get address of the global symbol
>From 98ad1a2cee6c7febb19b80c488dcbdf81316cb20 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 17:20:15 +0100
Subject: [PATCH 06/19] fix format
---
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 3216e82147da907..113f581e6522234 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2040,7 +2040,7 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
<< stringifyEnum(gpu::AddressSpace::Workgroup)
<< "> or " << int(GPUMemorySpace::kSharedMemorySpace);
}
- if(memrefType.hasStaticShape())
+ if (memrefType.hasStaticShape())
return emitOpError() << "result memref type must be memref<?xi8>";
return success();
}
>From d8860f5f340da1660be4427bbaf18361b5267e7a Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 18:55:52 +0100
Subject: [PATCH 07/19] address @ftynse comments
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 16 +++++++-------
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 21 ++++++++-----------
.../Dialect/GPU/dynamic-shared-memory.mlir | 4 ++--
mlir/test/Dialect/GPU/invalid.mlir | 6 +++---
4 files changed, 23 insertions(+), 24 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f3a37c62d3a7465..c6989f4b9bbbd6b 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -439,18 +439,20 @@ def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [] > {
let description = [{
This operation provides a memref pointer to the start of dynamic shared
memory, often referred to as workgroup memory. It's important to note that
- this dynamic shared memory needs to be allocated at kernel launch. One can
- conveniently utilize `the dynamic_shared_memory_size` parameter of
- `gpu.launch` for this purpose.
+ this dynamic shared memory needs to be allocated at kernel launch. One can
+ conveniently utilize the `dynamic_shared_memory_size` parameter of
+ `gpu.launch` for this purpose.
Examples:
```mlir
- %0 = gpu.dynamic.shared.memory : memref<?xi8, 3>
- %1 = memref.view %0[%c8192][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
- %2 = memref.view %0[%c16384][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %1 = memref.view %0[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>>
+ to memref<32x64xf32, #gpu.address_space<workgroup>>
+ %2 = memref.view %0[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>>
+ to memref<32x64xf32, #gpu.address_space<workgroup>>
```
}];
- let arguments = (ins );
+ let arguments = (ins);
let results = (outs Arg<MemRefRankOf<[I8], [1]>>:$resultMemref);
let assemblyFormat = [{ attr-dict `:` type($resultMemref) }];
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 4cc9e1c353189c6..5049fdff3a29c33 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -562,7 +562,7 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
gpu::DynamicSharedMemoryOp op,
const LLVMTypeConverter *typeConverter,
MemRefType memrefType, unsigned alignmentBit) {
- std::optional<LLVM::GlobalOp> existingGlobalOp;
+ LLVM::GlobalOp existingGlobalOp;
LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
assert(funcOp && "cannot find llvm.func op");
@@ -570,22 +570,20 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
assert(moduleOp && "cannot find gpu.module op");
- // Use already generated global op if it exists
- int index = 0;
- std::string prefix = llvm::formatv("__shmem_{0}", funcOp.getSymName());
+ uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+ // Use already generated global op if it exists.
moduleOp->walk([&](LLVM::GlobalOp globalOp) {
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
- if (arrayType.getNumElements() == 0) {
+ if (arrayType.getNumElements() == 0 &&
+ globalOp.getAlignment().value_or(0) == alignmentByte) {
existingGlobalOp = globalOp;
return WalkResult::interrupt();
}
}
- if (globalOp.getSymName().startswith(prefix))
- index++;
return WalkResult::advance();
});
- if (existingGlobalOp.has_value())
- return existingGlobalOp.value();
+ if (existingGlobalOp)
+ return existingGlobalOp;
// Generate a new global op
OpBuilder::InsertionGuard guard(rewriter);
@@ -593,9 +591,8 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
typeConverter->convertType(memrefType.getElementType()), 0);
- std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
+ std::string name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
- uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
return rewriter.create<LLVM::GlobalOp>(
funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignmentByte,
@@ -607,7 +604,7 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
MemRefType memrefType = op.getResultMemref().getType();
- auto elementType = typeConverter->convertType(memrefType.getElementType());
+ Type elementType = typeConverter->convertType(memrefType.getElementType());
assert(memrefType && "memref is not valid");
// Step 1: Generate a memref<0xi8> type
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index 04bacb16140ad97..f300481ba178d74 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
gpu.module @modules {
- // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+ // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
// CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
// CHECK-SAME: %[[arg0:.+]]: i64)
@@ -22,7 +22,7 @@ gpu.module @modules {
// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_0 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel : !llvm.ptr<3>
// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 768289819cc0e09..fd01f08ce552a2a 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -650,7 +650,7 @@ func.func @main() {
threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
- // expected-error @+1 {{op Address space must be address_space<workgroup> or 3}}
+ // expected-error @below {{op Address space must be address_space<workgroup> or 3}}
%0 = gpu.dynamic_shared_memory : memref<?xi8>
gpu.terminator
}
@@ -667,7 +667,7 @@ func.func @main() {
threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
- // expected-error @+1 {{result memref type must be memref<?xi8>}}
+ // expected-error @below {{result memref type must be memref<?xi8>}}
%0 = gpu.dynamic_shared_memory : memref<1xi8,3>
gpu.terminator
}
@@ -683,7 +683,7 @@ func.func @main(%arg0 : index) {
threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
- // expected-error @+1 {{op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, 3>}}
+ // expected-error @below {{op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, 3>}}
%0 = gpu.dynamic_shared_memory : memref<?xf32,3>
gpu.terminator
}
>From 371cddc4d970711433dea3f67af7a6871af332b5 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 9 Nov 2023 15:10:01 +0100
Subject: [PATCH 08/19] find unique name and test
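The naming loop added here probes `<base>_0`, `<base>_1`, ... until a candidate does not collide with an existing `llvm.mlir.global`. Below is a minimal standalone sketch of that probing. It assumes the module's global names have already been gathered into a `std::set`, whereas the patch itself re-walks the module for every candidate. `findUniqueName` is a hypothetical helper, not an MLIR API:

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical helper modelling the do/while loop in this patch: returns
// "<base>_<n>" for the smallest n >= 0 that is not already in `existing`.
std::string findUniqueName(const std::string &base,
                           const std::set<std::string> &existing) {
  assert(!base.empty() && "expected a non-empty base name");
  unsigned index = 0;
  std::string candidate;
  do {
    candidate = base + "_" + std::to_string(index++);
  } while (existing.count(candidate) != 0);
  return candidate;
}
```

In the updated test, `__shmem_dynamic_shared_memory_kernel` and `__shmem_dynamic_shared_memory_kernel_0` already exist. The first free candidate is therefore `__shmem_dynamic_shared_memory_kernel_1`, which is what the new CHECK line expects. The bare base name is never returned; only suffixed candidates are tried.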
---
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 25 +++++++++++++------
.../Dialect/GPU/dynamic-shared-memory.mlir | 7 +++---
2 files changed, 22 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 5049fdff3a29c33..a31179f495dc261 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -14,6 +14,8 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Visitors.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -562,8 +564,6 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
gpu::DynamicSharedMemoryOp op,
const LLVMTypeConverter *typeConverter,
MemRefType memrefType, unsigned alignmentBit) {
- LLVM::GlobalOp existingGlobalOp;
-
LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
assert(funcOp && "cannot find llvm.func op");
@@ -571,31 +571,42 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
assert(moduleOp && "cannot find gpu.module op");
uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
- // Use already generated global op if it exists.
+
+ LLVM::GlobalOp existingGlobalOp;
moduleOp->walk([&](LLVM::GlobalOp globalOp) {
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
if (arrayType.getNumElements() == 0 &&
globalOp.getAlignment().value_or(0) == alignmentByte) {
existingGlobalOp = globalOp;
- return WalkResult::interrupt();
}
}
- return WalkResult::advance();
});
if (existingGlobalOp)
return existingGlobalOp;
+ // Find unique name
+ int index = 0;
+ std::string symName, name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
+ WalkResult walkResult;
+ do {
+ symName = llvm::formatv("{0}_{1}", name, index++);
+ walkResult = moduleOp->walk([&](LLVM::GlobalOp globalOp) {
+ if (globalOp.getSymName() == symName)
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ } while (walkResult.wasInterrupted());
+
// Generate a new global op
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&moduleOp.front());
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
typeConverter->convertType(memrefType.getElementType()), 0);
- std::string name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
return rewriter.create<LLVM::GlobalOp>(
funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
- LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignmentByte,
+ LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
mlir::gpu::GPUMemorySpace::kSharedMemorySpace);
}
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index f300481ba178d74..308b5fb42e93e1e 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -1,8 +1,9 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
gpu.module @modules {
- // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
-
+ // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_1() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
// CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
// CHECK-SAME: %[[arg0:.+]]: i64)
gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {
@@ -22,7 +23,7 @@ gpu.module @modules {
// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_1 : !llvm.ptr<3>
// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
>From f458af8559e343b5c2fae425d9d9e3a6cbb1f168 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 9 Nov 2023 16:24:30 +0100
Subject: [PATCH 09/19] address comments
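After this patch, an existing global is reused only when it meets three conditions: it is a top-level zero-sized array, it lives in the requested address space, and it has the requested alignment. Otherwise a new global is created. Below is a minimal standalone model of that criterion. It assumes a plain `GlobalInfo` struct in place of `LLVM::GlobalOp`. `GlobalInfo`, `alignmentBytes`, and `findReusableGlobal` are hypothetical names, not MLIR APIs:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the parts of an LLVM::GlobalOp the lookup reads.
struct GlobalInfo {
  std::string name;
  unsigned addrSpace = 0;            // 0 when no addr_space attribute is set
  uint64_t numElements = 0;          // length of the !llvm.array type
  std::optional<uint64_t> alignment; // in bytes, if set
};

// Mirrors `alignmentBit / memrefType.getElementTypeBitWidth()`. The result is
// a byte count only because the op restricts the element type to i8.
uint64_t alignmentBytes(unsigned alignmentBit, unsigned elementBitWidth) {
  assert(elementBitWidth == 8 && "gpu.dynamic_shared_memory yields i8");
  return alignmentBit / elementBitWidth;
}

// Returns the first reusable global, or nullptr if a new one must be created.
const GlobalInfo *findReusableGlobal(const std::vector<GlobalInfo> &globals,
                                     unsigned addrSpace,
                                     uint64_t alignmentByte) {
  for (const GlobalInfo &global : globals)
    if (global.addrSpace == addrSpace && global.numElements == 0 &&
        global.alignment.value_or(0) == alignmentByte)
      return &global;
  return nullptr;
}
```

The updated test declares two globals with `alignment = 4` and one without an `addr_space`. None of them matches address space 3 with a 16-byte alignment, so a new global is emitted. The 16 assumes `NVVM::kSharedMemoryAlignmentBit` is 128, which is consistent with the `alignment = 16` the test checks.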
---
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 57 ++++++++++---------
.../lib/Conversion/GPUCommon/GPUOpsLowering.h | 5 +-
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 3 +-
.../Dialect/GPU/dynamic-shared-memory.mlir | 30 +++++++++-
4 files changed, 64 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index a31179f495dc261..f181c8418d9a224 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -559,45 +559,47 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
/// or uses existing symbol.
-LLVM::GlobalOp
-getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
- gpu::DynamicSharedMemoryOp op,
- const LLVMTypeConverter *typeConverter,
- MemRefType memrefType, unsigned alignmentBit) {
+LLVM::GlobalOp getDynamicSharedMemorySymbol(
+ ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
+ const LLVMTypeConverter *typeConverter, MemRefType memrefType,
+ unsigned alignmentBit, unsigned addressSpace) {
LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
assert(funcOp && "cannot find llvm.func op");
gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
assert(moduleOp && "cannot find gpu.module op");
+ // Step 1. Return existing global op if it exists
uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
-
- LLVM::GlobalOp existingGlobalOp;
- moduleOp->walk([&](LLVM::GlobalOp globalOp) {
- if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
- if (arrayType.getNumElements() == 0 &&
- globalOp.getAlignment().value_or(0) == alignmentByte) {
- existingGlobalOp = globalOp;
+ for (auto &innerOp : moduleOp->getRegions().front().front().getOperations()) {
+ if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+ if (globalOp.getAddrSpace() == addressSpace &&
+ arrayType.getNumElements() == 0 &&
+ globalOp.getAlignment().value_or(0) == alignmentByte) {
+ return globalOp;
+ }
}
}
- });
- if (existingGlobalOp)
- return existingGlobalOp;
+ }
- // Find unique name
+ // Step 2. Find a unique symbol name
int index = 0;
std::string symName, name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
- WalkResult walkResult;
+ bool nameExist;
do {
+ nameExist = false;
symName = llvm::formatv("{0}_{1}", name, index++);
- walkResult = moduleOp->walk([&](LLVM::GlobalOp globalOp) {
- if (globalOp.getSymName() == symName)
- return WalkResult::interrupt();
- return WalkResult::advance();
- });
- } while (walkResult.wasInterrupted());
+ for (auto &innerOp :
+ moduleOp->getRegions().front().front().getOperations()) {
+ if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
+ if (globalOp.getSymName() == symName)
+ nameExist = true;
+ }
+ }
+ } while (nameExist);
- // Generate a new global op
+ // Step 3. Generate a global op
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&moduleOp.front());
@@ -615,8 +617,8 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
MemRefType memrefType = op.getResultMemref().getType();
- Type elementType = typeConverter->convertType(memrefType.getElementType());
assert(memrefType && "memref is not valid");
+ Type elementType = typeConverter->convertType(memrefType.getElementType());
// Step 1: Generate a memref<0xi8> type
MemRefLayoutAttrInterface layout = {};
@@ -625,8 +627,9 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
// Step 2: Generate a global symbol or existing for the dynamic shared
// memory with memref<0xi8> type
- LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
- rewriter, op, getTypeConverter(), memrefType0sz, alignmentBit);
+ LLVM::GlobalOp shmemOp =
+ getDynamicSharedMemorySymbol(rewriter, op, getTypeConverter(),
+ memrefType0sz, alignmentBit, addressSpace);
assert(shmemOp && "cannot find module op or failed generating global op");
// Step 3. Get address of the global symbol
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index a77db4a036bad3f..0f16b69a3105608 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -22,9 +22,10 @@ struct GPUDynamicSharedMemoryOpLowering
using ConvertOpToLLVMPattern<
gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
+ unsigned addressSpace,
unsigned alignmentBit = 0)
: ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
- alignmentBit(alignmentBit) {}
+ alignmentBit(alignmentBit), addressSpace(addressSpace) {}
LogicalResult
matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
@@ -33,6 +34,8 @@ struct GPUDynamicSharedMemoryOpLowering
private:
// Alignment bit
unsigned alignmentBit;
+ // Address space of the shared memory
+ unsigned addressSpace;
};
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 86a77f557cb9579..52f73d2432a3f2f 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -326,7 +326,8 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
converter);
patterns.add<GPUDynamicSharedMemoryOpLowering>(
- converter, NVVM::kSharedMemoryAlignmentBit);
+ converter, NVVM::NVVMMemorySpace::kSharedMemorySpace,
+ NVVM::kSharedMemoryAlignmentBit);
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index 308b5fb42e93e1e..d17a97cc9dd6cc4 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -1,9 +1,10 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
gpu.module @modules {
- // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_1() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+ // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_2() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_1() {alignment = 16 : i64} : !llvm.array<0 x i8>
// CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
// CHECK-SAME: %[[arg0:.+]]: i64)
gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {
@@ -23,7 +24,7 @@ gpu.module @modules {
// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_1 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
@@ -46,4 +47,29 @@ gpu.module @modules {
// CHECK: "test.use.shared.memory"(%[[S22]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
gpu.return
}
+
+ gpu.func @device_function() {
+ %c8192 = arith.constant 8192 : index
+ %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+ gpu.return
+ }
}
>From fa45fa90dea85fecc8921eea3a2da5d878ab2741 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 10 Nov 2023 12:52:10 +0100
Subject: [PATCH 10/19] test func.func
---
.../Dialect/GPU/dynamic-shared-memory.mlir | 29 ++++++++++++++++++-
1 file changed, 28 insertions(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index d17a97cc9dd6cc4..c54ed3a5f30f52b 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -48,7 +48,8 @@ gpu.module @modules {
gpu.return
}
- gpu.func @device_function() {
+// CHECK-LABEL: llvm.func @gpu_device_function
+ gpu.func @gpu_device_function() {
%c8192 = arith.constant 8192 : index
%shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
%0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
@@ -72,4 +73,30 @@ gpu.module @modules {
gpu.return
}
+
+// CHECK-LABEL: llvm.func @func_device_function
+ func.func @func_device_function() {
+ %c8192 = arith.constant 8192 : index
+ %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+ %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+ "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+ func.return
+ }
}
>From 75c29e7f4921a1cf65869af297178fd3dfea28a6 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 10 Nov 2023 14:44:42 +0100
Subject: [PATCH 11/19] address @ftynse comments
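This patch factors the suffix loop into `SymbolTable::generateSymbolName`, which takes the conflict test as a callback. The two callers use it differently. The GPU lowering passes a read-only lookup into a `StringSet` of global names. `SymbolTable::insert` instead passes a checker that inserts the candidate and reports a conflict when the insertion fails, so the winning name is registered as a side effect. Below is a standalone sketch of that pattern. It assumes `std::string` and `std::set` in place of `SmallString`, `StringSet`, and the symbol table map. `generateName`, `uniqueByLookup`, and `uniqueByInsert` are hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>

// Analogue of SymbolTable::generateSymbolName: append "_<counter>" to `name`,
// bumping the shared counter on each attempt, until `hasConflict` is false.
template <typename ConflictChecker>
std::string generateName(const std::string &name, ConflictChecker hasConflict,
                         unsigned &uniquingCounter) {
  std::string buffer = name;
  const std::size_t originalLength = buffer.size();
  do {
    buffer.resize(originalLength); // drop the previous attempt's suffix
    buffer += '_';
    buffer += std::to_string(uniquingCounter++);
  } while (hasConflict(buffer));
  return buffer;
}

// Lowering-style checker: a pure lookup; the caller creates the symbol later.
std::string uniqueByLookup(const std::string &name,
                           const std::set<std::string> &existing,
                           unsigned &counter) {
  return generateName(
      name, [&](const std::string &c) { return existing.count(c) != 0; },
      counter);
}

// SymbolTable::insert-style checker: the check itself inserts the candidate,
// so the returned name is already registered when the loop exits.
std::string uniqueByInsert(const std::string &name,
                           std::set<std::string> &table, unsigned &counter) {
  std::string result = generateName(
      name, [&](const std::string &c) { return !table.insert(c).second; },
      counter);
  assert(table.count(result) == 1 && "winning name must be registered");
  return result;
}
```

Suppose `__dynamic_shmem__0` through `__dynamic_shmem__2` already exist. The lowering-style call then returns `__dynamic_shmem__3`, which matches the updated CHECK line. The doubled underscore comes from the `__dynamic_shmem_` prefix plus the `_` separator that the helper always appends.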
---
mlir/include/mlir/IR/SymbolTable.h | 18 ++++++++++++
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 28 +++++++++----------
mlir/lib/IR/SymbolTable.cpp | 20 ++++++-------
.../Dialect/GPU/dynamic-shared-memory.mlir | 14 +++++-----
4 files changed, 46 insertions(+), 34 deletions(-)
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 7f21f22eba951e3..597c6a9a1d89108 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -103,6 +103,24 @@ class SymbolTable {
Nested,
};
+ /// Generate a unique symbol name by appending `_<uniquingCounter>` to
+ /// `name`, incrementing the counter after every attempt, until
+ /// `uniqueChecker` returns false (i.e. reports no conflict).
+ template <unsigned N, typename UniqueChecker>
+ static SmallString<N> generateSymbolName(StringRef name,
+ UniqueChecker uniqueChecker,
+ unsigned &uniquingCounter) {
+ SmallString<N> nameBuffer(name);
+ unsigned originalLength = nameBuffer.size();
+ do {
+ nameBuffer.resize(originalLength);
+ nameBuffer += '_';
+ nameBuffer += std::to_string(uniquingCounter++);
+ } while (uniqueChecker(nameBuffer));
+
+ return nameBuffer;
+ }
+
/// Returns the name of the given symbol operation, aborting if no symbol is
/// present.
static StringAttr getSymbolName(Operation *symbol);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index f181c8418d9a224..7be79e44e0f6945 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
@@ -569,10 +570,14 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
assert(moduleOp && "cannot find gpu.module op");
- // Step 1. Return existing global op if it exists
uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+
+ // Step 1. Collect the symbol names of all LLVM::GlobalOp ops. If one of
+ // them is already suitable for dynamic shared memory, return it.
+ llvm::StringSet<> existingGlobalNames;
for (auto &innerOp : moduleOp->getRegions().front().front().getOperations()) {
if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
+ existingGlobalNames.insert(globalOp.getSymName());
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
if (globalOp.getAddrSpace() == addressSpace &&
arrayType.getNumElements() == 0 &&
@@ -584,20 +589,13 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
}
// Step 2. Find a unique symbol name
- int index = 0;
- std::string symName, name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
- bool nameExist;
- do {
- nameExist = false;
- symName = llvm::formatv("{0}_{1}", name, index++);
- for (auto &innerOp :
- moduleOp->getRegions().front().front().getOperations()) {
- if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
- if (globalOp.getSymName() == symName)
- nameExist = true;
- }
- }
- } while (nameExist);
+ unsigned uniquingCounter = 0;
+ SmallString<128> symName = SymbolTable::generateSymbolName<128>(
+ "__dynamic_shmem_",
+ [&](StringRef candidate) {
+ return existingGlobalNames.contains(candidate);
+ },
+ uniquingCounter);
// Step 3. Generate a global op
OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 7180ea432ea057d..e83d19553d62ce8 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -200,20 +200,16 @@ StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
// If the symbol was already in the table, also return.
if (symbolTable.lookup(name) == symbol)
return name;
- // If a conflict was detected, then the symbol will not have been added to
- // the symbol table. Try suffixes until we get to a unique name that works.
- SmallString<128> nameBuffer(name.getValue());
- unsigned originalLength = nameBuffer.size();
MLIRContext *context = symbol->getContext();
-
- // Iteratively try suffixes until we find one that isn't used.
- do {
- nameBuffer.resize(originalLength);
- nameBuffer += '_';
- nameBuffer += std::to_string(uniquingCounter++);
- } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
- .second);
+ SmallString<128> nameBuffer = generateSymbolName<128>(
+ name.getValue(),
+ [&](StringRef candidate) {
+ return !symbolTable
+ .insert({StringAttr::get(context, candidate), symbol})
+ .second;
+ },
+ uniquingCounter);
setSymbolName(symbol, nameBuffer);
return getSymbolName(symbol);
}
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index c54ed3a5f30f52b..fb45faaa712f7a9 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -1,10 +1,10 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
gpu.module @modules {
- // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_2() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
- llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
- llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
- llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_1() {alignment = 16 : i64} : !llvm.array<0 x i8>
+ // CHECK: llvm.mlir.global internal @__dynamic_shmem__3() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global internal @__dynamic_shmem__0() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global internal @__dynamic_shmem__1() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
+ llvm.mlir.global internal @__dynamic_shmem__2() {alignment = 16 : i64} : !llvm.array<0 x i8>
// CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
// CHECK-SAME: %[[arg0:.+]]: i64)
gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {
@@ -24,7 +24,7 @@ gpu.module @modules {
// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
@@ -58,7 +58,7 @@ gpu.module @modules {
// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
@@ -84,7 +84,7 @@ gpu.module @modules {
// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
>From dbe76be3ea4c82d685bb7823ff972ca491b2e015 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 10 Nov 2023 15:05:50 +0100
Subject: [PATCH 12/19] use `getOps`
---
.../lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 17 ++++++++---------
1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 7be79e44e0f6945..00cd8010015da81 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -575,15 +575,14 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
// Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
// LLVM::GlobalOp is suitable for shared memory, return it.
llvm::StringSet<> existingGlobalNames;
- for (auto &innerOp : moduleOp->getRegions().front().front().getOperations()) {
- if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
- existingGlobalNames.insert(globalOp.getSymName());
- if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
- if (globalOp.getAddrSpace() == addressSpace &&
- arrayType.getNumElements() == 0 &&
- globalOp.getAlignment().value_or(0) == alignmentByte) {
- return globalOp;
- }
+ for (auto globalOp :
+ moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
+ existingGlobalNames.insert(globalOp.getSymName());
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+ if (globalOp.getAddrSpace() == addressSpace &&
+ arrayType.getNumElements() == 0 &&
+ globalOp.getAlignment().value_or(0) == alignmentByte) {
+ return globalOp;
}
}
}
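PATCH 12 only swaps the manual `dyn_cast` loop for `getOps<LLVM::GlobalOp>()`. The reuse rule stays the same: a global is reused only if it is an array with zero elements, has the requested address space, and its alignment (defaulting to 0) equals the requested one. Every name is recorded along the way. The sketch below restates that rule in standalone form. The `GlobalInfo` struct and `findReusableShmemGlobal` are hypothetical simplifications of `LLVM::GlobalOp` and the loop.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified view of an llvm.mlir.global op: only the fields
// the lowering inspects.
struct GlobalInfo {
  std::string name;
  unsigned addrSpace;                // addr_space (0 when not written)
  bool isArray;                      // type is !llvm.array<...>
  std::uint64_t numElements;         // array length, if isArray
  std::optional<std::uint64_t> alignment;
};

// Returns the index of a reusable zero-sized global, if any, and records every
// name seen before that point (mirroring the early return in the patch).
std::optional<std::size_t> findReusableShmemGlobal(
    const std::vector<GlobalInfo> &globals, unsigned addrSpace,
    std::uint64_t alignmentByte, std::set<std::string> &existingNames) {
  for (std::size_t i = 0; i < globals.size(); ++i) {
    const GlobalInfo &g = globals[i];
    existingNames.insert(g.name);
    if (g.isArray && g.addrSpace == addrSpace && g.numElements == 0 &&
        g.alignment.value_or(0) == alignmentByte)
      return i;
  }
  return std::nullopt;
}
```

Consider the three globals in `dynamic-shared-memory.mlir` with address space 3 and alignment 16. None of them qualifies: `__dynamic_shmem__0` and `__dynamic_shmem__1` have alignment 4, and `__dynamic_shmem__2` has no address space. The test therefore expects a fresh global.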
>From 7913f41b145e1235537f9cc8f8b7646e40131eff Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Sat, 11 Nov 2023 12:42:30 +0100
Subject: [PATCH 13/19] Support ModuleOp, add verifier for parent op
---
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 2 ++
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 5 ++-
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 35 +++++++++++--------
mlir/test/Dialect/GPU/invalid.mlir | 5 +++
4 files changed, 31 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 286856324950eb7..ecf2ca6cde8bd6a 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -16,7 +16,9 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/DLTI/Traits.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index c6989f4b9bbbd6b..f98be24869726f2 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -433,7 +433,10 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
let hasVerifier = 1;
}
-def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [] > {
+def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory",
+ [ParentOneOf<["GPUFuncOp", "LaunchOp",
+ "mlir::LLVM::LLVMFuncOp", "mlir::func::FuncOp"]>]>
+{
let summary = "Get the memref for dynamic shared memory";
let description = [{
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 00cd8010015da81..8acecda6c93df8b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -560,15 +560,11 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
/// or uses existing symbol.
+template <typename ModuleTy>
LLVM::GlobalOp getDynamicSharedMemorySymbol(
- ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
- const LLVMTypeConverter *typeConverter, MemRefType memrefType,
- unsigned alignmentBit, unsigned addressSpace) {
- LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
- assert(funcOp && "cannot find llvm.func op");
-
- gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
- assert(moduleOp && "cannot find gpu.module op");
+ ConversionPatternRewriter &rewriter, ModuleTy moduleOp,
+ gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
+ MemRefType memrefType, unsigned alignmentBit, unsigned addressSpace) {
uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
@@ -576,7 +572,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
// LLVM::GlobalOp is suitable for shared memory, return it.
llvm::StringSet<> existingGlobalNames;
for (auto globalOp :
- moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
+ moduleOp->getRegion(0).front().template getOps<LLVM::GlobalOp>()) {
existingGlobalNames.insert(globalOp.getSymName());
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
if (globalOp.getAddrSpace() == addressSpace &&
@@ -604,7 +600,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
typeConverter->convertType(memrefType.getElementType()), 0);
return rewriter.create<LLVM::GlobalOp>(
- funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+ op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
mlir::gpu::GPUMemorySpace::kSharedMemorySpace);
}
@@ -614,7 +610,6 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
MemRefType memrefType = op.getResultMemref().getType();
- assert(memrefType && "memref is not valid");
Type elementType = typeConverter->convertType(memrefType.getElementType());
// Step 1: Generate a memref<0xi8> type
@@ -624,10 +619,20 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
// Step 2: Generate a global symbol or existing for the dynamic shared
// memory with memref<0xi8> type
- LLVM::GlobalOp shmemOp =
- getDynamicSharedMemorySymbol(rewriter, op, getTypeConverter(),
- memrefType0sz, alignmentBit, addressSpace);
- assert(shmemOp && "cannot find module op or failed generating global op");
+ LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+ LLVM::GlobalOp shmemOp = {};
+ if (gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>()) {
+ shmemOp =
+ getDynamicSharedMemorySymbol(rewriter, moduleOp, op, getTypeConverter(),
+ memrefType0sz, alignmentBit, addressSpace);
+ } else if (ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>()) {
+ shmemOp =
+ getDynamicSharedMemorySymbol(rewriter, moduleOp, op, getTypeConverter(),
+ memrefType0sz, alignmentBit, addressSpace);
+ }
+ if (!shmemOp) {
+ return rewriter.notifyMatchFailure(op, "failed generating global op");
+ }
// Step 3. Get address of the global symbol
OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index fd01f08ce552a2a..7dd16c529bc30e6 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -689,3 +689,8 @@ func.func @main(%arg0 : index) {
}
return
}
+
+// -----
+
+// expected-error @below {{op expects parent op to be one of 'gpu.func, gpu.launch, llvm.func, func.func'}}
+%0 = gpu.dynamic_shared_memory : memref<?xi8,3>
>From 4f6e83da62f2f32b77c7e273686f6d4b1bcad28b Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 13 Nov 2023 11:23:24 +0100
Subject: [PATCH 14/19] Add `[Pure]`, simplify with `getParentWithTrait<OpTrait::SymbolTable>()`
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 4 +---
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 21 ++++++-------------
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 4 ++++
mlir/test/Dialect/GPU/invalid.mlir | 4 ----
4 files changed, 11 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f98be24869726f2..7c4092f8a03f972 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -433,9 +433,7 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
let hasVerifier = 1;
}
-def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory",
- [ParentOneOf<["GPUFuncOp", "LaunchOp",
- "mlir::LLVM::LLVMFuncOp", "mlir::func::FuncOp"]>]>
+def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [Pure]>
{
let summary = "Get the memref for dynamic shared memory";
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 8acecda6c93df8b..6e885f9f1fe5ace 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -560,9 +560,8 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
/// or uses existing symbol.
-template <typename ModuleTy>
LLVM::GlobalOp getDynamicSharedMemorySymbol(
- ConversionPatternRewriter &rewriter, ModuleTy moduleOp,
+ ConversionPatternRewriter &rewriter, Operation *moduleOp,
gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
MemRefType memrefType, unsigned alignmentBit, unsigned addressSpace) {
@@ -594,7 +593,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
// Step 3. Generate a global op
OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(&moduleOp.front());
+ rewriter.setInsertionPoint(&moduleOp->getRegion(0).front().front());
auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
typeConverter->convertType(memrefType.getElementType()), 0);
@@ -621,18 +620,10 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
// memory with memref<0xi8> type
LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
LLVM::GlobalOp shmemOp = {};
- if (gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>()) {
- shmemOp =
- getDynamicSharedMemorySymbol(rewriter, moduleOp, op, getTypeConverter(),
- memrefType0sz, alignmentBit, addressSpace);
- } else if (ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>()) {
- shmemOp =
- getDynamicSharedMemorySymbol(rewriter, moduleOp, op, getTypeConverter(),
- memrefType0sz, alignmentBit, addressSpace);
- }
- if (!shmemOp) {
- return rewriter.notifyMatchFailure(op, "failed generating global op");
- }
+ Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
+ shmemOp =
+ getDynamicSharedMemorySymbol(rewriter, moduleOp, op, getTypeConverter(),
+ memrefType0sz, alignmentBit, addressSpace);
// Step 3. Get address of the global symbol
OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 113f581e6522234..811568b1a6cdac0 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -23,6 +23,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -2032,6 +2033,9 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
//===----------------------------------------------------------------------===//
LogicalResult gpu::DynamicSharedMemoryOp::verify() {
+ if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
+ return emitOpError() << "must be inside an op with symbol table";
+
MemRefType memrefType = getResultMemref().getType();
// Check address space
if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 7dd16c529bc30e6..af36a8efd7ece4f 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -690,7 +690,3 @@ func.func @main(%arg0 : index) {
return
}
-// -----
-
-// expected-error @below {{op expects parent op to be one of 'gpu.func, gpu.launch, llvm.func, func.func'}}
-%0 = gpu.dynamic_shared_memory : memref<?xi8,3>
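PATCH 14 replaces two things with a single upward walk to the nearest ancestor that carries `OpTrait::SymbolTable`: the `ParentOneOf` constraint and the `GPUModuleOp`/`ModuleOp` branching. That ancestor is the `gpu.module` when the function is nested in one, and the builtin module otherwise. The sketch below models that walk on a hypothetical `OpNode` tree. Like `Operation::getParentWithTrait`, it starts at the parent, not at the op itself.

```cpp
#include <cassert>
#include <string>

// Hypothetical minimal op tree: each node knows its parent and whether it
// carries the SymbolTable trait (builtin.module and gpu.module do).
struct OpNode {
  std::string name;
  bool isSymbolTable;
  OpNode *parent;
};

// Walk upward from `op` (excluding `op` itself) and return the first ancestor
// that is a symbol table, or nullptr if there is none.
OpNode *getParentSymbolTable(OpNode *op) {
  for (OpNode *p = op->parent; p != nullptr; p = p->parent)
    if (p->isSymbolTable)
      return p;
  return nullptr;
}
```

A null result corresponds to the new verifier error "must be inside an op with symbol table".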
>From c2e67deefc3f22d18cb97bb45dc2ffeeef0b6386 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Mon, 13 Nov 2023 16:03:40 +0100
Subject: [PATCH 15/19] remove unused headers
---
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 3 ---
1 file changed, 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index ecf2ca6cde8bd6a..82863b7715162f2 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -16,10 +16,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/DLTI/Traits.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
>From 0c4456404aec13d06a66be70784b2663eed390e2 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 14 Nov 2023 11:30:21 +0100
Subject: [PATCH 16/19] remove `GPUMemorySpace` enum. Improve error messages
---
mlir/include/mlir/Dialect/GPU/IR/GPUBase.td | 10 +++----
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 12 ---------
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +-
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 23 ++++++++--------
mlir/test/Dialect/GPU/invalid.mlir | 26 +++++++++++++++----
5 files changed, 38 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 057b507c394e80f..ccb9580adbd1f54 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -53,14 +53,12 @@ def GPU_Dialect : Dialect {
/// space.
static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
- /// Return true if the given MemRefType has an integer address
- /// space that matches the workgroup memory address space or
- /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+ /// Return true if the given MemRefType has an address space that matches
+ /// with the gpu::AddressSpaceAttr attribute with value 'workgroup`.
static bool hasWorkgroupMemoryAddressSpace(MemRefType type);
- /// Return true if the given Attribute has an integer address
- /// space that matches the workgroup memory address space or
- /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+ /// Return true if the given Attribute is an gpu::AddressSpaceAttr
+ /// attribute with value 'workgroup`.
static bool isWorkgroupMemoryAddressSpace(Attribute memorySpace);
}];
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 82863b7715162f2..14a1fac5fd255f3 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -32,18 +32,6 @@
namespace mlir {
namespace gpu {
-/// GPU memory space identifiers.
-enum GPUMemorySpace {
- /// Generic memory space identifier.
- kGenericMemorySpace = 0,
-
- /// Global memory space identifier.
- kGlobalMemorySpace = 1,
-
- /// Shared memory space identifier.
- kSharedMemorySpace = 3
-};
-
/// Utility class for the GPU dialect to represent triples of `Value`s
/// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
struct KernelDim3 {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6e885f9f1fe5ace..fca757f4f03c763 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -601,7 +601,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
return rewriter.create<LLVM::GlobalOp>(
op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
- mlir::gpu::GPUMemorySpace::kSharedMemorySpace);
+ addressSpace);
}
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 811568b1a6cdac0..83dde1aaab8d75f 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -168,8 +168,6 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
if (!memorySpace)
return false;
- if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
- return intAttr.getInt() == GPUMemorySpace::kSharedMemorySpace;
if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
return gpuAttr.getValue() == getWorkgroupAddressSpace();
return false;
@@ -2039,13 +2037,14 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
MemRefType memrefType = getResultMemref().getType();
// Check address space
if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
- return emitOpError() << "Address space must be "
+ return emitOpError() << "address space must be "
<< gpu::AddressSpaceAttr::getMnemonic() << "<"
- << stringifyEnum(gpu::AddressSpace::Workgroup)
- << "> or " << int(GPUMemorySpace::kSharedMemorySpace);
+ << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
+ }
+ if (memrefType.hasStaticShape()) {
+ return emitOpError() << "result memref type must be memref<?xi8, "
+ "#gpu.address_space<workgroup>>";
}
- if (memrefType.hasStaticShape())
- return emitOpError() << "result memref type must be memref<?xi8>";
return success();
}
@@ -2093,10 +2092,12 @@ TargetOptions::tokenizeCmdOptions() const {
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
llvm::StringSaver stringSaver(options.first);
StringRef opts = cmdOptions;
- // For a correct tokenization of the command line options `opts` must be
- // unquoted, otherwise the tokenization function returns a single string: the
- // unquoted `cmdOptions` -which is not the desired behavior.
- // Remove any quotes if they are at the beginning and end of the string:
+ // For a correct tokenization of the command line
+ // options `opts` must be unquoted, otherwise the
+ // tokenization function returns a single string:
+ // the unquoted `cmdOptions` -which is not the
+ // desired behavior. Remove any quotes if they are
+ // at the beginning and end of the string:
if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
opts.consume_front("\""), opts.consume_back("\"");
if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index af36a8efd7ece4f..a0b702eeeb87cb0 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -650,7 +650,7 @@ func.func @main() {
threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
- // expected-error @below {{op Address space must be address_space<workgroup> or 3}}
+ // expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
%0 = gpu.dynamic_shared_memory : memref<?xi8>
gpu.terminator
}
@@ -667,8 +667,8 @@ func.func @main() {
threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
- // expected-error @below {{result memref type must be memref<?xi8>}}
- %0 = gpu.dynamic_shared_memory : memref<1xi8,3>
+ // expected-error @below {{'gpu.dynamic_shared_memory' op result memref type must be memref<?xi8, #gpu.address_space<workgroup>>}}
+ %0 = gpu.dynamic_shared_memory : memref<1xi8, #gpu.address_space<workgroup>>
gpu.terminator
}
return
@@ -683,8 +683,24 @@ func.func @main(%arg0 : index) {
threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
- // expected-error @below {{op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, 3>}}
- %0 = gpu.dynamic_shared_memory : memref<?xf32,3>
+ // expected-error @below {{'gpu.dynamic_shared_memory' op address space must be address_space<workgroup>}}
+ %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<private>>
+ gpu.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @main(%arg0 : index) {
+ %shmemSize = arith.constant 8192 : i32
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
+ dynamic_shared_memory_size %shmemSize
+ {
+ // expected-error @below {{'gpu.dynamic_shared_memory' op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, #gpu.address_space<workgroup>}}
+ %0 = gpu.dynamic_shared_memory : memref<?xf32, #gpu.address_space<workgroup>>
gpu.terminator
}
return
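After PATCH 16, `DynamicSharedMemoryOp::verify()` checks three conditions in order: a symbol-table ancestor, the `#gpu.address_space<workgroup>` memory space, and a dynamic (non-static) shape. Element type and rank are rejected earlier by the ODS result constraint, which produces the "result #0 must be 1D memref of 8-bit signless integer values" message in the last test. The sketch below reproduces that check order. The `DynShmemOpInfo` summary struct and `verifyDynShmem` are hypothetical.

```cpp
#include <cassert>
#include <string>

// Hypothetical summary of the facts the custom verifier looks at.
struct DynShmemOpInfo {
  bool hasSymbolTableAncestor;   // some parent carries OpTrait::SymbolTable
  bool isWorkgroupAddressSpace;  // memory space is #gpu.address_space<workgroup>
  bool hasStaticShape;           // e.g. memref<1xi8, ...> rather than memref<?xi8, ...>
};

// Returns "" on success, otherwise the first failing check's message, in the
// same order as the verifier in the patch.
std::string verifyDynShmem(const DynShmemOpInfo &op) {
  if (!op.hasSymbolTableAncestor)
    return "must be inside an op with symbol table";
  if (!op.isWorkgroupAddressSpace)
    return "address space must be address_space<workgroup>";
  if (op.hasStaticShape)
    return "result memref type must be memref<?xi8, "
           "#gpu.address_space<workgroup>>";
  return "";
}
```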
>From 46a07f073a2667ee601a55dac900afc3e0c14819 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 14 Nov 2023 12:43:12 +0100
Subject: [PATCH 17/19] Use typeConverter for the address space
---
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 28 ++++++++++++-------
.../lib/Conversion/GPUCommon/GPUOpsLowering.h | 5 +---
.../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 3 +-
3 files changed, 20 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index fca757f4f03c763..b8e85f10069f713 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -560,13 +560,22 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
/// or uses existing symbol.
-LLVM::GlobalOp getDynamicSharedMemorySymbol(
- ConversionPatternRewriter &rewriter, Operation *moduleOp,
- gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
- MemRefType memrefType, unsigned alignmentBit, unsigned addressSpace) {
-
+LLVM::GlobalOp
+getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
+ Operation *moduleOp, gpu::DynamicSharedMemoryOp op,
+ const LLVMTypeConverter *typeConverter,
+ MemRefType memrefType, unsigned alignmentBit) {
uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+ FailureOr<unsigned> addressSpace =
+ typeConverter->getMemRefAddressSpace(memrefType);
+ if (failed(addressSpace)) {
+ op->emitError() << "conversion of memref memory space "
+ << memrefType.getMemorySpace()
+ << " to integer address space "
+ "failed. Consider adding memory space conversions.";
+ }
+
// Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
// LLVM::GlobalOp is suitable for shared memory, return it.
llvm::StringSet<> existingGlobalNames;
@@ -574,7 +583,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
moduleOp->getRegion(0).front().template getOps<LLVM::GlobalOp>()) {
existingGlobalNames.insert(globalOp.getSymName());
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
- if (globalOp.getAddrSpace() == addressSpace &&
+ if (globalOp.getAddrSpace() == addressSpace.value() &&
arrayType.getNumElements() == 0 &&
globalOp.getAlignment().value_or(0) == alignmentByte) {
return globalOp;
@@ -601,7 +610,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
return rewriter.create<LLVM::GlobalOp>(
op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
- addressSpace);
+ addressSpace.value());
}
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
@@ -621,9 +630,8 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
LLVM::GlobalOp shmemOp = {};
Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
- shmemOp =
- getDynamicSharedMemorySymbol(rewriter, moduleOp, op, getTypeConverter(),
- memrefType0sz, alignmentBit, addressSpace);
+ shmemOp = getDynamicSharedMemorySymbol(
+ rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
// Step 3. Get address of the global symbol
OpBuilder::InsertionGuard guard(rewriter);
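In this revision, `getDynamicSharedMemorySymbol` emits a diagnostic when `getMemRefAddressSpace` fails, but it does not return. Execution therefore still reaches `addressSpace.value()` on a failed result. One way to close that path is to exit early. In MLIR terms, the function could return a null `LLVM::GlobalOp`, and `matchAndRewrite` could turn that into `notifyMatchFailure`, as PATCH 13 did. The sketch below shows the pattern with `std::optional` standing in for `FailureOr`. The hypothetical `convertMemorySpace` stands in for the type converter, and its workgroup-to-3 mapping is an assumption based on the `!llvm.ptr<3>` in the NVVM test.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for LLVMTypeConverter::getMemRefAddressSpace: maps a
// memory-space name to a numeric LLVM address space, or fails.
std::optional<unsigned> convertMemorySpace(const std::string &space) {
  if (space == "workgroup")
    return 3u;  // assumed NVVM shared-memory address space
  return std::nullopt;
}

// Reports the failure and returns std::nullopt immediately, so the caller
// never dereferences a failed conversion.
std::optional<unsigned> pickAddressSpace(const std::string &space,
                                         std::string &diagnostic) {
  std::optional<unsigned> addressSpace = convertMemorySpace(space);
  if (!addressSpace) {
    diagnostic = "conversion of memref memory space " + space +
                 " to integer address space failed";
    return std::nullopt;  // early exit instead of falling through
  }
  return *addressSpace;
}
```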
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 0f16b69a3105608..a77db4a036bad3f 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -22,10 +22,9 @@ struct GPUDynamicSharedMemoryOpLowering
using ConvertOpToLLVMPattern<
gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
- unsigned addressSpace,
unsigned alignmentBit = 0)
: ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
- alignmentBit(alignmentBit), addressSpace(addressSpace) {}
+ alignmentBit(alignmentBit) {}
LogicalResult
matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
@@ -34,8 +33,6 @@ struct GPUDynamicSharedMemoryOpLowering
private:
// Alignment bit
unsigned alignmentBit;
- // Address space of the shared memory
- unsigned addressSpace;
};
struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 52f73d2432a3f2f..86a77f557cb9579 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -326,8 +326,7 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
converter);
patterns.add<GPUDynamicSharedMemoryOpLowering>(
- converter, NVVM::NVVMMemorySpace::kSharedMemorySpace,
- NVVM::kSharedMemoryAlignmentBit);
+ converter, NVVM::kSharedMemoryAlignmentBit);
// Explicitly drop memory space when lowering private memory
// attributions since NVVM models it as `alloca`s in the default
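After PATCH 17, the pattern receives only `alignmentBit`. `getDynamicSharedMemorySymbol` converts it with `alignmentBit / memrefType.getElementTypeBitWidth()`. The op's result element type is constrained to i8, so the quotient is the alignment in bytes. The sketch below restates that arithmetic. It assumes `NVVM::kSharedMemoryAlignmentBit` is 128, which is consistent with the `alignment = 16` the NVVM test expects. The helper name `alignmentInElements` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Mirrors `alignmentBit / memrefType.getElementTypeBitWidth()`: converts an
// alignment given in bits into units of the element type (bytes for i8).
std::uint64_t alignmentInElements(unsigned alignmentBit,
                                  unsigned elementBitWidth) {
  assert(elementBitWidth != 0 && "element type must have a bit width");
  return alignmentBit / elementBitWidth;
}
```

With the constructor's default `alignmentBit = 0`, the result is 0. The reuse check `getAlignment().value_or(0) == alignmentByte` would then match only globals with no (or zero) alignment.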
>From 8d150e6e29e525662de1a45f49ba3f60f892dbae Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 14 Nov 2023 14:43:35 +0100
Subject: [PATCH 18/19] remove unused headers
---
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 3 ---
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 10 ++++------
2 files changed, 4 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index b8e85f10069f713..9cb9eea013498ff 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,13 +9,10 @@
#include "GPUOpsLowering.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/SymbolTable.h"
-#include "mlir/IR/Visitors.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 83dde1aaab8d75f..38cce906166fdbb 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2092,12 +2092,10 @@ TargetOptions::tokenizeCmdOptions() const {
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> options;
llvm::StringSaver stringSaver(options.first);
StringRef opts = cmdOptions;
- // For a correct tokenization of the command line
- // options `opts` must be unquoted, otherwise the
- // tokenization function returns a single string:
- // the unquoted `cmdOptions` -which is not the
- // desired behavior. Remove any quotes if they are
- // at the beginning and end of the string:
+ // For a correct tokenization of the command line options `opts` must be
+ // unquoted, otherwise the tokenization function returns a single string: the
+ // unquoted `cmdOptions` -which is not the desired behavior.
+ // Remove any quotes if they are at the beginning and end of the string:
if (!opts.empty() && opts.front() == '"' && opts.back() == '"')
opts.consume_front("\""), opts.consume_back("\"");
if (!opts.empty() && opts.front() == '\'' && opts.back() == '\'')
>From 9b7a0216efc32a43271d38c46d07125e883c937a Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 16 Nov 2023 14:40:57 +0100
Subject: [PATCH 19/19] remove `template`
---
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 9cb9eea013498ff..46bf40ef6960505 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -577,7 +577,7 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
// LLVM::GlobalOp is suitable for shared memory, return it.
llvm::StringSet<> existingGlobalNames;
for (auto globalOp :
- moduleOp->getRegion(0).front().template getOps<LLVM::GlobalOp>()) {
+ moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
existingGlobalNames.insert(globalOp.getSymName());
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
if (globalOp.getAddrSpace() == addressSpace.value() &&
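PATCH 13 made `getDynamicSharedMemorySymbol` a template over `ModuleTy`. That made `moduleOp->getRegion(0).front()` type-dependent, so the member template call needed the `template` disambiguator. PATCH 14 changed the parameter to `Operation *`, after which the keyword was redundant. It still compiled because C++11 and later accept it outside templates. The sketch below illustrates the rule with hypothetical `Block`, `Module`, and `GlobalOp` types that only loosely mimic the MLIR ones.

```cpp
#include <cassert>
#include <string>

struct GlobalOp {
  static std::string name() { return "llvm.mlir.global"; }
};

struct Block {
  // Member template, loosely standing in for mlir::Block::getOps<OpTy>().
  template <typename OpTy> std::string getOps() const { return OpTy::name(); }
};

struct Module {
  Block body;
  const Block &front() const { return body; }
};

// `m.front()` depends on ModuleTy, so `template` is required before getOps.
template <typename ModuleTy> std::string viaTemplate(const ModuleTy &m) {
  return m.front().template getOps<GlobalOp>();
}

// With a concrete type, no disambiguator is needed.
std::string viaConcrete(const Module &m) {
  return m.front().getOps<GlobalOp>();
}
```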