[Mlir-commits] [mlir] [mlir][gpu] Introduce `gpu.dynamic_shared_memory` Op (PR #71546)

Guray Ozen llvmlistbot at llvm.org
Fri Nov 10 06:06:12 PST 2023


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/71546

From 82c405940bcab96c083d2ef7817e70865e7c0e5b Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 16:14:59 +0100
Subject: [PATCH 01/12] [mlir][gpu] Introduce `gpu.dynamic_shared_memory` Op
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

While the `gpu.launch` Op allows setting the size via the `dynamic_shared_memory_size` argument, accessing the dynamic shared memory is very convoluted. This PR implements the proposed Op, `gpu.dynamic_shared_memory`, which aims to simplify the use of dynamic shared memory.

RFC: https://discourse.llvm.org/t/rfc-simplifying-dynamic-shared-memory-access-in-gpu/

**Proposal from RFC**
This PR introduces the `gpu.dynamic_shared_memory` Op to make efficient use of dynamic shared memory. Dynamic shared memory is a powerful feature: the shared memory is sized at runtime, when the host launches the kernel, and can then be accessed directly from the device. I believe a similar story exists for AMDGPU.

**Current Way of Using Dynamic Shared Memory in MLIR**

Let me illustrate the challenges of using dynamic shared memory in MLIR with the example below. The process involves several steps:
- `memref.global`: declare the 0-sized array that LLVM's NVPTX backend expects
- `dynamic_shared_memory_size`: set the size of dynamic shared memory
- `memref.get_global`: access the global symbol
- `memref.reinterpret_cast` and `memref.subview`: many Ops for pointer arithmetic

```
// Step 1. Create 0-sized global symbol. Manually set the alignment
memref.global "private" @dynamicShmem  : memref<0xf16, 3> { alignment = 16 }
func.func @main() {
  // Step 2. Allocate shared memory
  gpu.launch blocks(...) threads(...)
    dynamic_shared_memory_size %c10000 {
    // Step 3. Access the global object
    %shmem = memref.get_global @dynamicShmem : memref<0xf16, 3>
    // Step 4. A sequence of `memref.reinterpret_cast` and `memref.subview` operations.
    %4 = memref.reinterpret_cast %shmem to offset: [0], sizes: [14, 64, 128],  strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
    %5 = memref.subview %4[7, 0, 0][7, 64, 128][1,1,1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
    %6 = memref.subview %5[2, 0, 0][1, 64, 128][1,1,1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<64x128xf16, strided<[128, 1], offset: 73728>, 3>
    %7 = memref.subview %6[0, 0][64, 64][1,1]  : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>
    %8 = memref.subview %6[32, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>
    // Step 5. Use
    "test.use.shared.memory"(%7) : (memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>) -> (index)
    gpu.terminator
  }
  return
}
```
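Each `memref.subview` above folds its start indices into the layout's element offset as `offset + sum(index_i * stride_i)`. As a minimal sketch of that arithmetic, assuming the row-major strides `[8192, 128, 1]` and `[128, 1]` from the example (the helper `subviewOffset` is hypothetical, not an MLIR API), the following reproduces the offsets 57344, 73728 and 77824 that appear in the result types:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

using std::int64_t;

// Hypothetical helper: element offset of a strided subview,
// i.e. baseOffset + sum(indices[i] * strides[i]).
template <std::size_t Rank>
constexpr int64_t subviewOffset(int64_t baseOffset,
                                const std::array<int64_t, Rank> &indices,
                                const std::array<int64_t, Rank> &strides) {
  int64_t offset = baseOffset;
  for (std::size_t i = 0; i < Rank; ++i)
    offset += indices[i] * strides[i];
  return offset;
}

// %5 = memref.subview %4[7, 0, 0] (strides [8192, 128, 1], base offset 0)
constexpr int64_t kOffset5 = subviewOffset<3>(0, {7, 0, 0}, {8192, 128, 1});
// %6 = memref.subview %5[2, 0, 0]
constexpr int64_t kOffset6 =
    subviewOffset<3>(kOffset5, {2, 0, 0}, {8192, 128, 1});
// %7 = memref.subview %6[0, 0] and %8 = memref.subview %6[32, 0] (strides [128, 1])
constexpr int64_t kOffset7 = subviewOffset<2>(kOffset6, {0, 0}, {128, 1});
constexpr int64_t kOffset8 = subviewOffset<2>(kOffset6, {32, 0}, {128, 1});
```

Every step of this bookkeeping has to be spelled out in the IR as a separate Op with a fully written-out result type.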

Here is the program above rewritten with the proposed `gpu.dynamic_shared_memory` Op:

```
func.func @main() {
  gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
    // Step 1: Obtain shared memory directly
    %shmem = gpu.dynamic_shared_memory : memref<?xi8, 3>
    %c147456 = arith.constant 147456 : index
    %c155648 = arith.constant 155648 : index
    %7 = memref.view %shmem[%c147456][] : memref<?xi8, 3> to memref<64x64xf16, 3>
    %8 = memref.view %shmem[%c155648][] : memref<?xi8, 3> to memref<64x64xf16, 3>

    // Step 2: Utilize the shared memory
    "test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
    gpu.terminator
  }
  return
}
```
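The constants passed to `memref.view` are byte offsets into the `i8` buffer: the element offsets from the strided layouts in the first example (73728 and 77824 `f16` elements) multiplied by `sizeof(f16) = 2`. The same arithmetic gives a lower bound for `dynamic_shared_memory_size`, namely the first byte past the last tile. A minimal sketch, with hypothetical helper names:

```cpp
#include <cstdint>

using std::int64_t;

// Size in bytes of one f16 element.
constexpr int64_t kF16Bytes = 2;

// Hypothetical helper: byte offset for memref.view from an element offset.
constexpr int64_t viewByteOffset(int64_t elementOffset, int64_t elementBytes) {
  return elementOffset * elementBytes;
}

// Hypothetical helper: first byte past a rows x cols tile that starts at
// byteOffset.
constexpr int64_t tileEndByte(int64_t byteOffset, int64_t rows, int64_t cols,
                              int64_t elementBytes) {
  return byteOffset + rows * cols * elementBytes;
}

constexpr int64_t kView7 = viewByteOffset(73728, kF16Bytes); // %c147456
constexpr int64_t kView8 = viewByteOffset(77824, kF16Bytes); // %c155648

// Smallest dynamic_shared_memory_size covering both 64x64xf16 tiles
// (%8 starts after %7 and has the same size, so it ends last).
constexpr int64_t kMinDynamicShmemBytes =
    tileEndByte(kView8, 64, 64, kF16Bytes);
```

Note that the two tiles are adjacent: `%7` spans 8192 bytes and ends exactly where `%8` begins.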
---
 mlir/include/mlir/Dialect/GPU/IR/GPUBase.td   | 10 +++
 mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 13 +++
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 23 +++++
 .../include/mlir/Dialect/LLVMIR/NVVMDialect.h |  3 +
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 87 +++++++++++++++++++
 .../lib/Conversion/GPUCommon/GPUOpsLowering.h | 21 +++++
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |  3 +
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 41 ++++++---
 .../Dialect/GPU/dynamic-shared-memory.mlir    | 35 ++++++++
 mlir/test/Dialect/GPU/invalid.mlir            | 49 +++++++++++
 10 files changed, 275 insertions(+), 10 deletions(-)
 create mode 100644 mlir/test/Dialect/GPU/dynamic-shared-memory.mlir

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 755c82d8b75c9c0..057b507c394e80f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -52,6 +52,16 @@ def GPU_Dialect : Dialect {
     /// Returns the numeric value used to identify the private memory address
     /// space.
     static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
+    
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the workgroup memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool hasWorkgroupMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute has an integer address
+    /// space that matches the workgroup memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool isWorkgroupMemoryAddressSpace(Attribute memorySpace);  
   }];
 
   let dependentDialects = ["arith::ArithDialect"];
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 14a1fac5fd255f3..286856324950eb7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -17,6 +17,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -32,6 +33,18 @@
 namespace mlir {
 namespace gpu {
 
+/// GPU memory space identifiers.
+enum GPUMemorySpace {
+  /// Generic memory space identifier.
+  kGenericMemorySpace = 0,
+
+  /// Global memory space identifier.
+  kGlobalMemorySpace = 1,
+
+  /// Shared memory space identifier.
+  kSharedMemorySpace = 3
+};
+
 /// Utility class for the GPU dialect to represent triples of `Value`s
 /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
 struct KernelDim3 {
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6375d35f4311295..f3a37c62d3a7465 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -433,6 +433,29 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
   let hasVerifier = 1;
 }
 
+def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [] > {
+  let summary = "Get the memref for dynamic shared memory";
+  
+  let description = [{
+    This operation provides a memref pointer to the start of dynamic shared 
+    memory, often referred to as workgroup memory. It's important to note that
+     this dynamic shared memory needs to be allocated at kernel launch. One can 
+     conveniently utilize `the dynamic_shared_memory_size` parameter of 
+     `gpu.launch` for this purpose.
+   
+    Examples: 
+    ```mlir        
+    %0 = gpu.dynamic.shared.memory : memref<?xi8, 3>
+    %1 = memref.view %0[%c8192][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    %2 = memref.view %0[%c16384][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    ```
+  }];  
+  let arguments = (ins );
+  let results = (outs Arg<MemRefRankOf<[I8], [1]>>:$resultMemref);
+  let assemblyFormat = [{ attr-dict `:` type($resultMemref) }];
+  let hasVerifier = 1;
+}
+
 def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
 
 def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 8ff8f850a9c1858..08019e77ae6af8a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -27,6 +27,9 @@
 namespace mlir {
 namespace NVVM {
 
+// Shared memory has 128-bit alignment
+constexpr int kSharedMemoryAlignmentBit = 128;
+
 /// NVVM memory space identifiers.
 enum NVVMMemorySpace {
   /// Global memory space identifier.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6d2585aa30ab4c5..fbea498ee27caa7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -554,6 +555,92 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
   return IntegerAttr::get(IntegerType::get(ctx, 64), space);
 }
 
+/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
+/// or uses existing symbol.
+LLVM::GlobalOp getDynamicSharedMemorySymbol(
+    ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
+    const LLVMTypeConverter *typeConverter, MemRefType memrefType, unsigned alignmentBit) {
+  std::optional<LLVM::GlobalOp> existingGlobalOp;
+
+  LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+  assert(funcOp && "cannot find llvm.func op");
+
+  gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
+  assert(moduleOp && "cannot find gpu.module op");
+
+  // Use already generated global op if it exists
+  int index = 0;
+  std::string prefix = llvm::formatv("__shmem_{0}", funcOp.getSymName());
+  moduleOp->walk([&](LLVM::GlobalOp globalOp) {
+    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+      if (arrayType.getNumElements() == 0) {
+        existingGlobalOp = globalOp;
+        return WalkResult::interrupt();
+      }
+    }
+    if (globalOp.getSymName().startswith(prefix))
+      index++;
+    return WalkResult::advance();
+  });
+  if (existingGlobalOp.has_value())
+    return existingGlobalOp.value();
+
+  // Generate a new global op
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(&moduleOp.front());
+
+  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
+      typeConverter->convertType(memrefType.getElementType()), 0);
+  std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
+  // TODO: better alignment calculation
+  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+  return rewriter.create<LLVM::GlobalOp>(
+      funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+      LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignmentByte,
+      mlir::gpu::GPUMemorySpace::kSharedMemorySpace);
+}
+
+LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
+    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MemRefType memrefType = op.getResultMemref().getType();
+  auto elementType = typeConverter->convertType(memrefType.getElementType());
+  assert(memrefType && "memref is not valid");
+  
+  // Step 1: Generate a memref<0xi8> type
+  MemRefLayoutAttrInterface layout = {};
+  auto memrefType0sz = MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());  
+
+  // Step 2: Generate a global symbol or existing for the dynamic shared
+  // memory with memref<0xi8> type
+  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
+      rewriter, op, getTypeConverter(), memrefType0sz ,alignmentBit);
+  assert(shmemOp && "cannot find module op or failed generating global op");
+
+  // Step 3. Get address of the global symbol
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+  auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+  Type baseType = basePtr->getResultTypes().front();
+
+  // Step 4. Generate GEP using offsets
+  SmallVector<LLVM::GEPArg> gepArgs = {0};
+  Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
+                                                basePtr, gepArgs);
+  // Step 5. Create a memref descriptor
+  SmallVector<Value> shape, strides;
+  Value sizeBytes;
+  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
+                           sizeBytes);
+  auto memRefDescriptor = this->createMemRefDescriptor(
+      loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
+
+  // Step 5. Replace the op with memref descriptor
+  rewriter.replaceOp(op, {memRefDescriptor});
+  return success();
+}
+
 void mlir::populateGpuMemorySpaceAttributeConversions(
     TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
   typeConverter.addTypeAttributeConversion(
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index bd90286494d8035..a77db4a036bad3f 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,27 @@
 
 namespace mlir {
 
+/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
+/// create a 0-sized global array symbol similar as LLVM expects. It constructs
+/// a memref descriptor with these values and return it.
+struct GPUDynamicSharedMemoryOpLowering
+    : public ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
+  GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
+                                   unsigned alignmentBit = 0)
+      : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
+        alignmentBit(alignmentBit) {}
+
+  LogicalResult
+  matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+
+private:
+  // Alignment bit
+  unsigned alignmentBit;
+};
+
 struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
   GPUFuncOpLowering(const LLVMTypeConverter &converter,
                     unsigned allocaAddrSpace, unsigned workgroupAddrSpace,
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 935e3d2a4095003..86a77f557cb9579 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -325,6 +325,9 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
            GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
           converter);
 
+  patterns.add<GPUDynamicSharedMemoryOpLowering>(
+      converter, NVVM::kSharedMemoryAlignmentBit);
+
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
   // memory space and does not support `alloca`s with addrspace(5).
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5eb2cadc884e151..3216e82147da907 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -164,17 +164,20 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
 // GPUDialect
 //===----------------------------------------------------------------------===//
 
-/// GPU memory space identifiers.
-enum GPUMemorySpace {
-  /// Generic memory space identifier.
-  kGenericMemorySpace = 0,
-
-  /// Global memory space identifier.
-  kGlobalMemorySpace = 1,
+bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == GPUMemorySpace::kSharedMemorySpace;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == getWorkgroupAddressSpace();
+  return false;
+}
 
-  /// Shared memory space identifier.
-  kSharedMemorySpace = 3
-};
+bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isWorkgroupMemoryAddressSpace(memorySpace);
+}
 
 bool GPUDialect::isKernel(Operation *op) {
   UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
@@ -2024,6 +2027,24 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicSharedMemoryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::DynamicSharedMemoryOp::verify() {
+  MemRefType memrefType = getResultMemref().getType();
+  // Check address space
+  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
+    return emitOpError() << "Address space must be "
+                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
+                         << stringifyEnum(gpu::AddressSpace::Workgroup)
+                         << "> or " << int(GPUMemorySpace::kSharedMemorySpace);
+  }
+  if(memrefType.hasStaticShape()) 
+    return emitOpError() << "result memref type must be memref<?xi8>";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPU target options
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
new file mode 100644
index 000000000000000..d3ca12b7131c469
--- /dev/null
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
+
+gpu.module @modules {
+  // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  
+  // CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
+  // CHECK-SAME: %[[arg0:.+]]: i64)
+  gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {    
+    %c1 = arith.constant 1 : index
+    %c100 = arith.constant 100 : index
+    %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+
+    %0 = memref.view %shmem[%c100][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+    
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_0 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][100] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+    gpu.return
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index c8c0b7d24bc3ab2..768289819cc0e09 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -640,3 +640,52 @@ module {
   // expected-error @+1 {{'gpu.binary' op attribute 'offloadingHandler' failed to satisfy constraint: any attribute with the `OffloadingTranslationAttrTrait` trait.}}
   gpu.binary @binary <1> [#gpu.object<#nvvm.target, "">]
 }
+
+// -----
+
+func.func @main() {
+  %shmemSize = arith.constant 10000 : i32
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             dynamic_shared_memory_size %shmemSize
+  {
+    // expected-error @+1 {{op Address space must be address_space<workgroup> or 3}}
+    %0 = gpu.dynamic_shared_memory : memref<?xi8>  
+    gpu.terminator
+  }
+  return
+}
+
+
+// -----
+
+func.func @main() {
+  %shmemSize = arith.constant 8192 : i32
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             dynamic_shared_memory_size %shmemSize
+  {
+    // expected-error @+1 {{result memref type must be memref<?xi8>}}
+    %0 = gpu.dynamic_shared_memory : memref<1xi8,3>  
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @main(%arg0 : index) {
+  %shmemSize = arith.constant 8192 : i32
+  %c1 = arith.constant 1 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+             threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
+             dynamic_shared_memory_size %shmemSize
+  {
+    // expected-error @+1 {{op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, 3>}}
+    %0 = gpu.dynamic_shared_memory : memref<?xf32,3>  
+    gpu.terminator
+  }
+  return
+}
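In the lowering above, `getDynamicSharedMemorySymbol` computes the global's alignment as `alignmentBit / memrefType.getElementTypeBitWidth()`, and the NVVM pattern is registered with `alignmentBit = NVVM::kSharedMemoryAlignmentBit = 128`. Because the verifier restricts the result to `i8` elements, this yields 128 / 8 = 16, the `alignment = 16` checked in `dynamic-shared-memory.mlir`. The plain C++ sketch below (hypothetical function names, not the MLIR code) shows that this formula coincides with a bits-to-bytes conversion only for 8-bit elements:

```cpp
#include <cstdint>

using std::uint64_t;

// Shared memory alignment used for NVVM in the patch, in bits.
constexpr uint64_t kSharedMemoryAlignmentBit = 128;

// Hypothetical mirror of the patch's formula:
// alignmentBit / memrefType.getElementTypeBitWidth().
constexpr uint64_t patchAlignment(uint64_t alignmentBit,
                                  uint64_t elementBitWidth) {
  return alignmentBit / elementBitWidth;
}

// Hypothetical bits-to-bytes conversion, independent of the element type.
constexpr uint64_t alignmentInBytes(uint64_t alignmentBit) {
  return alignmentBit / 8;
}
```

For example, with 32-bit elements `patchAlignment` would give 4, whereas 128 bits is 16 bytes.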

From f7babc582ea39a01e17617790652f76939757b39 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 16:16:12 +0100
Subject: [PATCH 02/12] add nl

---
 mlir/test/Dialect/GPU/dynamic-shared-memory.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index d3ca12b7131c469..a2706fa6bde7ae9 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -32,4 +32,4 @@ gpu.module @modules {
 
     gpu.return
   }
-}
\ No newline at end of file
+}

From 8924447541382bacd05da1b9a207e9bb196409f6 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 16:18:39 +0100
Subject: [PATCH 03/12] remove todo

---
 mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index fbea498ee27caa7..aaee407166f581a 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -592,7 +592,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
   auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
       typeConverter->convertType(memrefType.getElementType()), 0);
   std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
-  // TODO: better alignment calculation
+  
   uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
   return rewriter.create<LLVM::GlobalOp>(
       funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,

From badeaa3e0c9375699eaa104310a3ba1fafbffe1d Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 16:21:38 +0100
Subject: [PATCH 04/12] improve the test

---
 .../Dialect/GPU/dynamic-shared-memory.mlir    | 21 +++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index a2706fa6bde7ae9..04bacb16140ad97 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -7,11 +7,16 @@ gpu.module @modules {
   // CHECK-SAME: %[[arg0:.+]]: i64)
   gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {    
     %c1 = arith.constant 1 : index
-    %c100 = arith.constant 100 : index
+    %c8192 = arith.constant 8192 : index
+    %c16384 = arith.constant 16384 : index
     %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+    %shmem2 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
 
-    %0 = memref.view %shmem[%c100][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
     "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+    %1 = memref.view %shmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    "test.use.shared.memory"(%1) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
     
 // CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
 // CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
@@ -20,7 +25,7 @@ gpu.module @modules {
 // CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_0 : !llvm.ptr<3>
 // CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
-// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][100] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
 // CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
 // CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
 // CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
@@ -29,7 +34,15 @@ gpu.module @modules {
 // CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
 // CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
 // CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
-
+// CHECK: %[[S15:.+]] = llvm.getelementptr %4[16384] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S16:.+]] = llvm.insertvalue %[[S15]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S17:.+]] = llvm.insertvalue %[[S3]], %[[S16]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S18:.+]] = llvm.insertvalue %[[S1]], %[[S17]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S19:.+]] = llvm.insertvalue %[[S2]], %[[S18]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S20:.+]] = llvm.insertvalue %[[S0]], %[[S19]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S21:.+]] = llvm.insertvalue %[[S1]], %[[S20]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S22:.+]] = builtin.unrealized_conversion_cast %[[S21]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S22]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
     gpu.return
   }
 }
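The CHECK lines in the improved test fill the standard rank-2 memref descriptor `(allocated ptr, aligned ptr, offset, sizes[2], strides[2])`: field 0 is the `llvm.mlir.addressof` result, field 1 is the `i8` GEP by the view's byte offset, the offset is 0, the sizes are `[32, 64]` and the strides `[64, 1]`. The sketch below mimics that layout with a hypothetical struct, loosely modeled on `StridedMemRefType` from MLIR's `CRunnerUtils.h`; to keep it checkable at compile time, pointers are modeled as byte offsets from the start of the `__shmem_dynamic_shared_memory_kernel_0` symbol (an assumption of this sketch):

```cpp
#include <cstdint>

using std::int64_t;

// Hypothetical mirror of
// !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>.
// Pointers are modeled as byte offsets from the shared memory symbol.
struct MemRefDescriptor2D {
  int64_t allocated;  // [0]: llvm.mlir.addressof of the symbol
  int64_t aligned;    // [1]: llvm.getelementptr on i8 by the view offset
  int64_t offset;     // [2]: element offset, 0 after memref.view
  int64_t sizes[2];   // [3]: {rows, cols}
  int64_t strides[2]; // [4]: row-major (identity layout)
};

// Hypothetical helper: descriptor for
// memref.view %shmem[byteShift][] to memref<rows x cols x T>.
constexpr MemRefDescriptor2D makeView(int64_t byteShift, int64_t rows,
                                      int64_t cols) {
  return MemRefDescriptor2D{0, byteShift, 0, {rows, cols}, {cols, 1}};
}

constexpr MemRefDescriptor2D kView0 = makeView(8192, 32, 64);  // %0
constexpr MemRefDescriptor2D kView1 = makeView(16384, 32, 64); // %1
```

Both views share the same allocated pointer; only the aligned pointer differs, which is why the second CHECK sequence reuses `%[[S6]]` and inserts a new GEP result into field 1.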

From 91221dfb84ebad8326a30c4c8a369a6af9bdef1b Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 17:01:36 +0100
Subject: [PATCH 05/12] fix format

---
 .../lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index aaee407166f581a..4cc9e1c353189c6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -557,9 +557,11 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
 
 /// Generates a symbol with 0-sized array type for dynamic shared memory usage,
 /// or uses existing symbol.
-LLVM::GlobalOp getDynamicSharedMemorySymbol(
-    ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
-    const LLVMTypeConverter *typeConverter, MemRefType memrefType, unsigned alignmentBit) {
+LLVM::GlobalOp
+getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
+                             gpu::DynamicSharedMemoryOp op,
+                             const LLVMTypeConverter *typeConverter,
+                             MemRefType memrefType, unsigned alignmentBit) {
   std::optional<LLVM::GlobalOp> existingGlobalOp;
 
   LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -592,7 +594,7 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
   auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
       typeConverter->convertType(memrefType.getElementType()), 0);
   std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
-  
+
   uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
   return rewriter.create<LLVM::GlobalOp>(
       funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
@@ -607,15 +609,16 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
   MemRefType memrefType = op.getResultMemref().getType();
   auto elementType = typeConverter->convertType(memrefType.getElementType());
   assert(memrefType && "memref is not valid");
-  
+
   // Step 1: Generate a memref<0xi8> type
   MemRefLayoutAttrInterface layout = {};
-  auto memrefType0sz = MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());  
+  auto memrefType0sz =
+      MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
 
   // Step 2: Generate a global symbol or existing for the dynamic shared
   // memory with memref<0xi8> type
   LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
-      rewriter, op, getTypeConverter(), memrefType0sz ,alignmentBit);
+      rewriter, op, getTypeConverter(), memrefType0sz, alignmentBit);
   assert(shmemOp && "cannot find module op or failed generating global op");
 
   // Step 3. Get address of the global symbol

>From 98ad1a2cee6c7febb19b80c488dcbdf81316cb20 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 17:20:15 +0100
Subject: [PATCH 06/12] fix format

---
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 3216e82147da907..113f581e6522234 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2040,7 +2040,7 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
                          << stringifyEnum(gpu::AddressSpace::Workgroup)
                          << "> or " << int(GPUMemorySpace::kSharedMemorySpace);
   }
-  if(memrefType.hasStaticShape()) 
+  if (memrefType.hasStaticShape())
     return emitOpError() << "result memref type must be memref<?xi8>";
   return success();
 }
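By this point a `gpu.dynamic_shared_memory` result type is checked in two places. The ODS constraint `MemRefRankOf<[I8], [1]>` requires a rank-1 memref of `i8`. `DynamicSharedMemoryOp::verify()` then requires the workgroup address space, followed by a dynamic shape. The sketch below reproduces that order of checks in plain C++ without MLIR. `MemRefDesc` and `kWorkgroupAddressSpace` are hypothetical stand-ins for `mlir::MemRefType` and its memory-space attribute, and both `#gpu.address_space<workgroup>` and the integer `3` are modelled as address space 3. The messages are shortened forms of the ones exercised in `invalid.mlir`.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical, MLIR-free summary of a memref result type.
struct MemRefDesc {
  int rank;
  unsigned elementBitWidth;              // 8 for i8
  bool isSignlessInteger;                // false for f32
  bool hasStaticShape;                   // true for memref<1xi8>
  std::optional<unsigned> addressSpace;  // nullopt = default memory space
};

// Assumption: workgroup memory is modelled as numeric address space 3.
constexpr unsigned kWorkgroupAddressSpace = 3;

// Returns "" if the type is accepted, otherwise the first failing check.
std::string verifyDynamicSharedMemoryResult(const MemRefDesc &t) {
  assert(t.rank >= 0 && "a ranked memref has a non-negative rank");
  // 1. ODS constraint MemRefRankOf<[I8], [1]>.
  if (t.rank != 1 || !t.isSignlessInteger || t.elementBitWidth != 8)
    return "result #0 must be 1D memref of 8-bit signless integer values";
  // 2. Custom verifier: address space is checked first...
  if (t.addressSpace != kWorkgroupAddressSpace)
    return "Address space must be address_space<workgroup> or 3";
  // 3. ...then the shape must be dynamic (memref<?xi8>).
  if (t.hasStaticShape)
    return "result memref type must be memref<?xi8>";
  return "";
}
```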

>From d8860f5f340da1660be4427bbaf18361b5267e7a Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 7 Nov 2023 18:55:52 +0100
Subject: [PATCH 07/12] address @ftynse comments

---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 16 +++++++-------
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 21 ++++++++-----------
 .../Dialect/GPU/dynamic-shared-memory.mlir    |  4 ++--
 mlir/test/Dialect/GPU/invalid.mlir            |  6 +++---
 4 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index f3a37c62d3a7465..c6989f4b9bbbd6b 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -439,18 +439,20 @@ def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [] > {
   let description = [{
     This operation provides a memref pointer to the start of dynamic shared 
     memory, often referred to as workgroup memory. It's important to note that
-     this dynamic shared memory needs to be allocated at kernel launch. One can 
-     conveniently utilize `the dynamic_shared_memory_size` parameter of 
-     `gpu.launch` for this purpose.
+    this dynamic shared memory needs to be allocated at kernel launch. One can 
+    conveniently utilize the `dynamic_shared_memory_size` parameter of 
+    `gpu.launch` for this purpose.
    
     Examples: 
     ```mlir        
-    %0 = gpu.dynamic.shared.memory : memref<?xi8, 3>
-    %1 = memref.view %0[%c8192][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
-    %2 = memref.view %0[%c16384][] : memref<?xi8, 3> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+    %1 = memref.view %0[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> 
+                            to memref<32x64xf32, #gpu.address_space<workgroup>>
+    %2 = memref.view %0[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> 
+                            to memref<32x64xf32, #gpu.address_space<workgroup>>
     ```
   }];  
-  let arguments = (ins );
+  let arguments = (ins);
   let results = (outs Arg<MemRefRankOf<[I8], [1]>>:$resultMemref);
   let assemblyFormat = [{ attr-dict `:` type($resultMemref) }];
   let hasVerifier = 1;
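The description example carves two `memref<32x64xf32>` views out of one `memref<?xi8>` buffer, at byte shifts 8192 and 16384. The arithmetic sketch below uses hypothetical helper names. It checks that each view occupies 32 * 64 * 4 = 8192 bytes and that the two views sit next to each other without overlapping. It also shows that the launch must request a `dynamic_shared_memory_size` of at least 24576 bytes for both views to fit. Nothing in this patch appears to check that bound; the user is responsible for it.

```cpp
#include <cassert>
#include <cstdint>

// Size in bytes of a statically shaped rows x cols view whose elements are
// elemBytes wide, e.g. memref<32x64xf32> -> 32 * 64 * 4 bytes.
std::uint64_t viewBytes(std::uint64_t rows, std::uint64_t cols,
                        std::uint64_t elemBytes) {
  assert(elemBytes > 0 && "element size must be positive");
  return rows * cols * elemBytes;
}

// True if the byte ranges [aShift, aShift + aBytes) and
// [bShift, bShift + bBytes) intersect.
bool overlaps(std::uint64_t aShift, std::uint64_t aBytes,
              std::uint64_t bShift, std::uint64_t bBytes) {
  return aShift < bShift + bBytes && bShift < aShift + aBytes;
}

// Smallest dynamic shared memory size (bytes) that keeps a view of the given
// size, starting at byteShift, inside the buffer.
std::uint64_t requiredShmemBytes(std::uint64_t byteShift, std::uint64_t bytes) {
  return byteShift + bytes;
}
```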
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 4cc9e1c353189c6..5049fdff3a29c33 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -562,7 +562,7 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
                              gpu::DynamicSharedMemoryOp op,
                              const LLVMTypeConverter *typeConverter,
                              MemRefType memrefType, unsigned alignmentBit) {
-  std::optional<LLVM::GlobalOp> existingGlobalOp;
+  LLVM::GlobalOp existingGlobalOp;
 
   LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
   assert(funcOp && "cannot find llvm.func op");
@@ -570,22 +570,20 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
   gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
   assert(moduleOp && "cannot find gpu.module op");
 
-  // Use already generated global op if it exists
-  int index = 0;
-  std::string prefix = llvm::formatv("__shmem_{0}", funcOp.getSymName());
+  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+  // Use already generated global op if it exists.
   moduleOp->walk([&](LLVM::GlobalOp globalOp) {
     if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
-      if (arrayType.getNumElements() == 0) {
+      if (arrayType.getNumElements() == 0 &&
+          globalOp.getAlignment().value_or(0) == alignmentByte) {
         existingGlobalOp = globalOp;
         return WalkResult::interrupt();
       }
     }
-    if (globalOp.getSymName().startswith(prefix))
-      index++;
     return WalkResult::advance();
   });
-  if (existingGlobalOp.has_value())
-    return existingGlobalOp.value();
+  if (existingGlobalOp)
+    return existingGlobalOp;
 
   // Generate a new global op
   OpBuilder::InsertionGuard guard(rewriter);
@@ -593,9 +591,8 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
 
   auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
       typeConverter->convertType(memrefType.getElementType()), 0);
-  std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
+  std::string name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
 
-  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
   return rewriter.create<LLVM::GlobalOp>(
       funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
       LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignmentByte,
@@ -607,7 +604,7 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   MemRefType memrefType = op.getResultMemref().getType();
-  auto elementType = typeConverter->convertType(memrefType.getElementType());
+  Type elementType = typeConverter->convertType(memrefType.getElementType());
   assert(memrefType && "memref is not valid");
 
   // Step 1: Generate a memref<0xi8> type
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index 04bacb16140ad97..f300481ba178d74 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
 
 gpu.module @modules {
-  // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
   
   // CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
   // CHECK-SAME: %[[arg0:.+]]: i64)
@@ -22,7 +22,7 @@ gpu.module @modules {
 // CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
 // CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_0 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel : !llvm.ptr<3>
 // CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
 // CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 768289819cc0e09..fd01f08ce552a2a 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -650,7 +650,7 @@ func.func @main() {
              threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
              dynamic_shared_memory_size %shmemSize
   {
-    // expected-error @+1 {{op Address space must be address_space<workgroup> or 3}}
+    // expected-error @below {{op Address space must be address_space<workgroup> or 3}}
     %0 = gpu.dynamic_shared_memory : memref<?xi8>  
     gpu.terminator
   }
@@ -667,7 +667,7 @@ func.func @main() {
              threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
              dynamic_shared_memory_size %shmemSize
   {
-    // expected-error @+1 {{result memref type must be memref<?xi8>}}
+    // expected-error @below {{result memref type must be memref<?xi8>}}
     %0 = gpu.dynamic_shared_memory : memref<1xi8,3>  
     gpu.terminator
   }
@@ -683,7 +683,7 @@ func.func @main(%arg0 : index) {
              threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) 
              dynamic_shared_memory_size %shmemSize
   {
-    // expected-error @+1 {{op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, 3>}}
+    // expected-error @below {{op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, 3>}}
     %0 = gpu.dynamic_shared_memory : memref<?xf32,3>  
     gpu.terminator
   }

>From 371cddc4d970711433dea3f67af7a6871af332b5 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 9 Nov 2023 15:10:01 +0100
Subject: [PATCH 08/12] find unique name and test

---
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 25 +++++++++++++------
 .../Dialect/GPU/dynamic-shared-memory.mlir    |  7 +++---
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 5049fdff3a29c33..a31179f495dc261 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -14,6 +14,8 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Visitors.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -562,8 +564,6 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
                              gpu::DynamicSharedMemoryOp op,
                              const LLVMTypeConverter *typeConverter,
                              MemRefType memrefType, unsigned alignmentBit) {
-  LLVM::GlobalOp existingGlobalOp;
-
   LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
   assert(funcOp && "cannot find llvm.func op");
 
@@ -571,31 +571,42 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
   assert(moduleOp && "cannot find gpu.module op");
 
   uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
-  // Use already generated global op if it exists.
+
+  LLVM::GlobalOp existingGlobalOp;
   moduleOp->walk([&](LLVM::GlobalOp globalOp) {
     if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
       if (arrayType.getNumElements() == 0 &&
           globalOp.getAlignment().value_or(0) == alignmentByte) {
         existingGlobalOp = globalOp;
-        return WalkResult::interrupt();
       }
     }
-    return WalkResult::advance();
   });
   if (existingGlobalOp)
     return existingGlobalOp;
 
+  // Find unique name
+  int index = 0;
+  std::string symName, name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
+  WalkResult walkResult;
+  do {
+    symName = llvm::formatv("{0}_{1}", name, index++);
+    walkResult = moduleOp->walk([&](LLVM::GlobalOp globalOp) {
+      if (globalOp.getSymName() == symName)
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+  } while (walkResult.wasInterrupted());
+
   // Generate a new global op
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(&moduleOp.front());
 
   auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
       typeConverter->convertType(memrefType.getElementType()), 0);
-  std::string name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
 
   return rewriter.create<LLVM::GlobalOp>(
       funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
-      LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignmentByte,
+      LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
       mlir::gpu::GPUMemorySpace::kSharedMemorySpace);
 }
 
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index f300481ba178d74..308b5fb42e93e1e 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -1,8 +1,9 @@
 // RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
 
 gpu.module @modules {
-  // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
-  
+  // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_1() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
+  llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>  
   // CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
   // CHECK-SAME: %[[arg0:.+]]: i64)
   gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {    
@@ -22,7 +23,7 @@ gpu.module @modules {
 // CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
 // CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_1 : !llvm.ptr<3>
 // CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
 // CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8

>From f458af8559e343b5c2fae425d9d9e3a6cbb1f168 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 9 Nov 2023 16:24:30 +0100
Subject: [PATCH 09/12] address comments

---
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 57 ++++++++++---------
 .../lib/Conversion/GPUCommon/GPUOpsLowering.h |  5 +-
 .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp        |  3 +-
 .../Dialect/GPU/dynamic-shared-memory.mlir    | 30 +++++++++-
 4 files changed, 64 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index a31179f495dc261..f181c8418d9a224 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -559,45 +559,47 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
 
 /// Generates a symbol with 0-sized array type for dynamic shared memory usage,
 /// or uses existing symbol.
-LLVM::GlobalOp
-getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
-                             gpu::DynamicSharedMemoryOp op,
-                             const LLVMTypeConverter *typeConverter,
-                             MemRefType memrefType, unsigned alignmentBit) {
+LLVM::GlobalOp getDynamicSharedMemorySymbol(
+    ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
+    const LLVMTypeConverter *typeConverter, MemRefType memrefType,
+    unsigned alignmentBit, unsigned addressSpace) {
   LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
   assert(funcOp && "cannot find llvm.func op");
 
   gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
   assert(moduleOp && "cannot find gpu.module op");
 
+  // Step 1. Return existing global op if it exists
   uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
-
-  LLVM::GlobalOp existingGlobalOp;
-  moduleOp->walk([&](LLVM::GlobalOp globalOp) {
-    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
-      if (arrayType.getNumElements() == 0 &&
-          globalOp.getAlignment().value_or(0) == alignmentByte) {
-        existingGlobalOp = globalOp;
+  for (auto &innerOp : moduleOp->getRegions().front().front().getOperations()) {
+    if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
+      if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+        if (globalOp.getAddrSpace() == addressSpace &&
+            arrayType.getNumElements() == 0 &&
+            globalOp.getAlignment().value_or(0) == alignmentByte) {
+          return globalOp;
+        }
       }
     }
-  });
-  if (existingGlobalOp)
-    return existingGlobalOp;
+  }
 
-  // Find unique name
+  // Step 2. Find a unique symbol name
   int index = 0;
   std::string symName, name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
-  WalkResult walkResult;
+  bool nameExist;
   do {
+    nameExist = false;
     symName = llvm::formatv("{0}_{1}", name, index++);
-    walkResult = moduleOp->walk([&](LLVM::GlobalOp globalOp) {
-      if (globalOp.getSymName() == symName)
-        return WalkResult::interrupt();
-      return WalkResult::advance();
-    });
-  } while (walkResult.wasInterrupted());
+    for (auto &innerOp :
+         moduleOp->getRegions().front().front().getOperations()) {
+      if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
+        if (globalOp.getSymName() == symName)
+          nameExist = true;
+      }
+    }
+  } while (nameExist);
 
-  // Generate a new global op
+  // Step 3. Generate a global op
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(&moduleOp.front());
 
@@ -615,8 +617,8 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   MemRefType memrefType = op.getResultMemref().getType();
-  Type elementType = typeConverter->convertType(memrefType.getElementType());
   assert(memrefType && "memref is not valid");
+  Type elementType = typeConverter->convertType(memrefType.getElementType());
 
   // Step 1: Generate a memref<0xi8> type
   MemRefLayoutAttrInterface layout = {};
@@ -625,8 +627,9 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
 
   // Step 2: Generate a global symbol or existing for the dynamic shared
   // memory with memref<0xi8> type
-  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
-      rewriter, op, getTypeConverter(), memrefType0sz, alignmentBit);
+  LLVM::GlobalOp shmemOp =
+      getDynamicSharedMemorySymbol(rewriter, op, getTypeConverter(),
+                                   memrefType0sz, alignmentBit, addressSpace);
   assert(shmemOp && "cannot find module op or failed generating global op");
 
   // Step 3. Get address of the global symbol
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index a77db4a036bad3f..0f16b69a3105608 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -22,9 +22,10 @@ struct GPUDynamicSharedMemoryOpLowering
   using ConvertOpToLLVMPattern<
       gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
   GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
+                                   unsigned addressSpace,
                                    unsigned alignmentBit = 0)
       : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
-        alignmentBit(alignmentBit) {}
+        alignmentBit(alignmentBit), addressSpace(addressSpace) {}
 
   LogicalResult
   matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
@@ -33,6 +34,8 @@ struct GPUDynamicSharedMemoryOpLowering
 private:
   // Alignment bit
   unsigned alignmentBit;
+  // Address space of the shared memory
+  unsigned addressSpace;
 };
 
 struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 86a77f557cb9579..52f73d2432a3f2f 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -326,7 +326,8 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
           converter);
 
   patterns.add<GPUDynamicSharedMemoryOpLowering>(
-      converter, NVVM::kSharedMemoryAlignmentBit);
+      converter, NVVM::NVVMMemorySpace::kSharedMemorySpace,
+      NVVM::kSharedMemoryAlignmentBit);
 
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index 308b5fb42e93e1e..d17a97cc9dd6cc4 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -1,9 +1,10 @@
 // RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
 
 gpu.module @modules {
-  // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_1() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_2() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
   llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
   llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>  
+  llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_1() {alignment = 16 : i64} : !llvm.array<0 x i8>  
   // CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
   // CHECK-SAME: %[[arg0:.+]]: i64)
   gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {    
@@ -23,7 +24,7 @@ gpu.module @modules {
 // CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
 // CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_1 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
 // CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
 // CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
@@ -46,4 +47,29 @@ gpu.module @modules {
 // CHECK: "test.use.shared.memory"(%[[S22]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
     gpu.return
   }
+
+  gpu.func @device_function()  {    
+    %c8192 = arith.constant 8192 : index
+    %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+    %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+    gpu.return
+  }
 }
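After this patch, an `llvm.mlir.global` from the top level of the `gpu.module` body is reused only if it meets three conditions. It must be in the requested address space, it must be a zero-length array, and its alignment must match. Nested ops are no longer walked. The updated test depends on this rule. `_1` has `alignment = 16` but no `addr_space`, which means address space 0, so the lowering has to create `_2`. The sketch below restates the Step 1 conditions in plain C++. `GlobalInfo`, `alignmentFor` and `findReusableGlobal` are hypothetical names. The alignment formula is copied from the patch (`alignmentBit / elementBitWidth`). For `i8` it yields the 16 that the test expects when `alignmentBit` is 128.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical summary of a top-level llvm.mlir.global in the gpu.module.
struct GlobalInfo {
  std::string name;
  unsigned addrSpace;                        // 0 when addr_space is absent
  std::optional<std::uint64_t> numElements;  // set only for !llvm.array types
  std::optional<std::uint64_t> alignment;    // alignment attribute, if any
};

// Alignment as computed in the patch: alignmentBit / element bit width.
std::uint64_t alignmentFor(unsigned alignmentBit, unsigned elementBitWidth) {
  assert(elementBitWidth > 0 && "element bit width must be positive");
  return alignmentBit / elementBitWidth;
}

// Index of the first reusable global: same address space, zero-length
// array, and same alignment (a missing alignment counts as 0).
std::optional<std::size_t>
findReusableGlobal(const std::vector<GlobalInfo> &globals,
                   unsigned addressSpace, std::uint64_t alignment) {
  for (std::size_t i = 0; i < globals.size(); ++i) {
    const GlobalInfo &g = globals[i];
    if (g.addrSpace == addressSpace && g.numElements == std::uint64_t{0} &&
        g.alignment.value_or(0) == alignment)
      return i;
  }
  return std::nullopt;  // caller must create a new, uniquely named global
}
```

Applied to the three pre-existing globals in the test, no candidate matches address space 3 with alignment 16, so a new symbol is generated. That symbol is then reused by `@gpu_device_function` in the same module.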

>From fa45fa90dea85fecc8921eea3a2da5d878ab2741 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 10 Nov 2023 12:52:10 +0100
Subject: [PATCH 10/12] test func.func

---
 .../Dialect/GPU/dynamic-shared-memory.mlir    | 29 ++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index d17a97cc9dd6cc4..c54ed3a5f30f52b 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -48,7 +48,8 @@ gpu.module @modules {
     gpu.return
   }
 
-  gpu.func @device_function()  {    
+// CHECK-LABEL: llvm.func @gpu_device_function
+  gpu.func @gpu_device_function()  {    
     %c8192 = arith.constant 8192 : index
     %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
     %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
@@ -72,4 +73,30 @@ gpu.module @modules {
 
     gpu.return
   }
+
+// CHECK-LABEL: llvm.func @func_device_function
+  func.func @func_device_function()  {    
+    %c8192 = arith.constant 8192 : index
+    %shmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+    %0 = memref.view %shmem[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<32x64xf32, #gpu.address_space<workgroup>>
+    "test.use.shared.memory"(%0) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+// CHECK: %[[S0:.+]] = llvm.mlir.constant(32 : index) : i64
+// CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+// CHECK: %[[S8:.+]] = llvm.insertvalue %[[S7]], %[[S6]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S9:.+]] = llvm.insertvalue %[[S3]], %[[S8]][2] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S10:.+]] = llvm.insertvalue %[[S1]], %[[S9]][3, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S11:.+]] = llvm.insertvalue %[[S2]], %[[S10]][4, 1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S12:.+]] = llvm.insertvalue %[[S0]], %[[S11]][3, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S13:.+]] = llvm.insertvalue %[[S1]], %[[S12]][4, 0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
+// CHECK: %[[S14:.+]] = builtin.unrealized_conversion_cast %[[S13]] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> to memref<32x64xf32, #gpu.address_space<workgroup>>
+// CHECK: "test.use.shared.memory"(%[[S14]]) : (memref<32x64xf32, #gpu.address_space<workgroup>>) -> ()
+
+    func.return
+  }
 }
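The CHECK lines for `@func_device_function` spell out the lowered memref descriptor one field at a time:

- field 0 holds the global's address;
- field 1 holds that address advanced by 8192 bytes;
- field 2 holds an offset of 0;
- field 3 holds sizes `[32, 64]`;
- field 4 holds strides `[64, 1]`.

Below is a plain C++ mirror of that descriptor that makes the resulting addressing explicit. Pointers are modelled as byte addresses, and `MemRefDescriptor2D`, `makeView` and `elementAddress` are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Mirror of !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
// with pointers modelled as byte addresses.
struct MemRefDescriptor2D {
  std::uint64_t allocated;              // field 0: base of the shared buffer
  std::uint64_t aligned;                // field 1: base + byte shift (the GEP)
  std::int64_t offset;                  // field 2: element offset
  std::array<std::int64_t, 2> sizes;    // field 3
  std::array<std::int64_t, 2> strides;  // field 4: row-major strides
};

// Descriptor for memref.view %shmem[%byteShift][] to memref<rows x cols x T>.
MemRefDescriptor2D makeView(std::uint64_t base, std::uint64_t byteShift,
                            std::int64_t rows, std::int64_t cols) {
  assert(rows > 0 && cols > 0 && "static shape expected");
  return {base, base + byteShift, 0, {rows, cols}, {cols, 1}};
}

// Byte address of element (i, j) when elements are elemBytes wide.
std::uint64_t elementAddress(const MemRefDescriptor2D &d, std::int64_t i,
                             std::int64_t j, std::uint64_t elemBytes) {
  std::int64_t linear = d.offset + i * d.strides[0] + j * d.strides[1];
  return d.aligned + static_cast<std::uint64_t>(linear) * elemBytes;
}
```

If the buffer is placed at address 0, the last `f32` element of the view, `(31, 63)`, starts at byte 16380. The view therefore covers bytes 8192 through 16383 of the dynamic shared memory.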

>From 75c29e7f4921a1cf65869af297178fd3dfea28a6 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 10 Nov 2023 14:44:42 +0100
Subject: [PATCH 11/12] address @ftynse comments

---
 mlir/include/mlir/IR/SymbolTable.h            | 18 ++++++++++++
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 28 +++++++++----------
 mlir/lib/IR/SymbolTable.cpp                   | 20 ++++++-------
 .../Dialect/GPU/dynamic-shared-memory.mlir    | 14 +++++-----
 4 files changed, 46 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
index 7f21f22eba951e3..597c6a9a1d89108 100644
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -103,6 +103,24 @@ class SymbolTable {
     Nested,
   };
 
+  /// Generate a unique symbol name. Iteratively increase uniquingCounter
+  /// and use it as a suffix for symbol names until uniqueChecker does not
+  /// detect any conflict.
+  template <unsigned N, typename UniqueChecker>
+  static SmallString<N> generateSymbolName(StringRef name,
+                                           UniqueChecker uniqueChecker,
+                                           unsigned &uniquingCounter) {
+    SmallString<N> nameBuffer(name);
+    unsigned originalLength = nameBuffer.size();
+    do {
+      nameBuffer.resize(originalLength);
+      nameBuffer += '_';
+      nameBuffer += std::to_string(uniquingCounter++);
+    } while (uniqueChecker(nameBuffer));
+
+    return nameBuffer;
+  }
+
   /// Returns the name of the given symbol operation, aborting if no symbol is
   /// present.
   static StringAttr getSymbolName(Operation *symbol);
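On every attempt, `generateSymbolName` appends `_` plus the counter to the unchanged base name. It returns the first candidate that the checker does not report as taken. Because `uniquingCounter` is passed by reference, the caller can keep counting across calls. Below is a mirror of that loop that uses only the standard library. It is a sketch: `std::string` and `std::function` stand in for `SmallString<N>` and the `UniqueChecker` template parameter, and the function name is hypothetical. One consequence concerns the caller in `GPUOpsLowering.cpp` in this patch. As written in this revision, its prefix `"__dynamic_shmem_"` already ends in `_`, so the first generated name would be `__dynamic_shmem__0`, with a double underscore.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_set>

// Returns the first "<name>_<counter>" that isTaken rejects as a conflict,
// leaving uniquingCounter one past the suffix that was used.
std::string generateSymbolNameSketch(
    const std::string &name,
    const std::function<bool(const std::string &)> &isTaken,
    unsigned &uniquingCounter) {
  std::string buffer = name;
  const std::size_t originalLength = buffer.size();
  do {
    buffer.resize(originalLength);  // drop the previous attempt's suffix
    buffer += '_';
    buffer += std::to_string(uniquingCounter++);
  } while (isTaken(buffer));
  return buffer;
}
```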
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index f181c8418d9a224..7be79e44e0f6945 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Visitors.h"
 #include "llvm/ADT/SmallVectorExtras.h"
+#include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
@@ -569,10 +570,14 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
   gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
   assert(moduleOp && "cannot find gpu.module op");
 
-  // Step 1. Return existing global op if it exists
   uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+
+  // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
+  // LLVM::GlobalOp is suitable for shared memory, return it.
+  llvm::StringSet<> existingGlobalNames;
   for (auto &innerOp : moduleOp->getRegions().front().front().getOperations()) {
     if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
+      existingGlobalNames.insert(globalOp.getSymName());
       if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
         if (globalOp.getAddrSpace() == addressSpace &&
             arrayType.getNumElements() == 0 &&
@@ -584,20 +589,13 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
   }
 
   // Step 2. Find a unique symbol name
-  int index = 0;
-  std::string symName, name = llvm::formatv("__shmem_{0}", funcOp.getSymName());
-  bool nameExist;
-  do {
-    nameExist = false;
-    symName = llvm::formatv("{0}_{1}", name, index++);
-    for (auto &innerOp :
-         moduleOp->getRegions().front().front().getOperations()) {
-      if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
-        if (globalOp.getSymName() == symName)
-          nameExist = true;
-      }
-    }
-  } while (nameExist);
+  unsigned uniquingCounter = 0;
+  SmallString<128> symName = SymbolTable::generateSymbolName<128>(
+      "__dynamic_shmem_",
+      [&](StringRef candidate) {
+        return existingGlobalNames.contains(candidate);
+      },
+      uniquingCounter);
 
   // Step 3. Generate a global op
   OpBuilder::InsertionGuard guard(rewriter);
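The reworked `getDynamicSharedMemorySymbol` first scans the module's `LLVM::GlobalOp`s. It reuses any zero-sized array that is in the requested address space and has the requested alignment. Only if none matches does it derive a fresh name from the `__dynamic_shmem_` prefix. The sketch below models that flow with plain containers. `GlobalInfo` and `getOrCreateDynamicShmem` are hypothetical stand-ins, and the sketch assumes every global already has an array type, which the real code checks with `dyn_cast<LLVM::LLVMArrayType>`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for the LLVM::GlobalOp fields the lowering inspects.
// Assumes every global is an array; numElements == 0 marks a 0-sized array.
struct GlobalInfo {
  std::string symName;
  unsigned addrSpace;
  std::uint64_t numElements;
  std::optional<std::uint64_t> alignment;
};

// Returns the index of a reusable global, or appends a new one with a unique
// name and returns its index.
std::size_t getOrCreateDynamicShmem(std::vector<GlobalInfo> &globals,
                                    unsigned addressSpace,
                                    std::uint64_t alignmentByte) {
  std::set<std::string> existingGlobalNames;
  // Step 1: collect names; reuse a matching zero-sized global if one exists.
  for (std::size_t i = 0; i < globals.size(); ++i) {
    const GlobalInfo &g = globals[i];
    existingGlobalNames.insert(g.symName);
    if (g.addrSpace == addressSpace && g.numElements == 0 &&
        g.alignment.value_or(0) == alignmentByte)
      return i;
  }
  // Step 2: prefix "__dynamic_shmem_" plus separator '_' plus counter,
  // retried until the name is not already used.
  unsigned counter = 0;
  std::string symName;
  do {
    symName = std::string("__dynamic_shmem_") + '_' + std::to_string(counter++);
  } while (existingGlobalNames.count(symName) != 0);
  // Step 3: create the new zero-sized global.
  globals.push_back({symName, addressSpace, 0, alignmentByte});
  return globals.size() - 1;
}
```

The test file pre-declares three globals. `__dynamic_shmem__0` and `__dynamic_shmem__1` are in address space 3 with alignment 4, and `__dynamic_shmem__2` is in the default address space with alignment 16. A request for address space 3 with 16-byte alignment matches none of them, so the sketch creates `__dynamic_shmem__3`. That is the name the updated CHECK lines expect. The double underscore comes from the prefix's trailing `_` followed by the separator the helper appends.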
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 7180ea432ea057d..e83d19553d62ce8 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -200,20 +200,16 @@ StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
   // If the symbol was already in the table, also return.
   if (symbolTable.lookup(name) == symbol)
     return name;
-  // If a conflict was detected, then the symbol will not have been added to
-  // the symbol table. Try suffixes until we get to a unique name that works.
-  SmallString<128> nameBuffer(name.getValue());
-  unsigned originalLength = nameBuffer.size();
 
   MLIRContext *context = symbol->getContext();
-
-  // Iteratively try suffixes until we find one that isn't used.
-  do {
-    nameBuffer.resize(originalLength);
-    nameBuffer += '_';
-    nameBuffer += std::to_string(uniquingCounter++);
-  } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
-                .second);
+  SmallString<128> nameBuffer = generateSymbolName<128>(
+      name.getValue(),
+      [&](StringRef candidate) {
+        return !symbolTable
+                    .insert({StringAttr::get(context, candidate), symbol})
+                    .second;
+      },
+      uniquingCounter);
   setSymbolName(symbol, nameBuffer);
   return getSymbolName(symbol);
 }
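In `SymbolTable::insert`, the checker both tests and claims a name. `symbolTable.insert(...)` records the candidate when it is free, so a failed insertion (`.second == false`) is what signals a conflict. When the loop exits, the winning name is already in the table. The following is a minimal sketch of that test-and-claim pattern. It uses a `std::unordered_map` in place of MLIR's internal symbol map, and `SymbolMap` and `insertWithUniqueName` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical miniature symbol table: name -> symbol id.
using SymbolMap = std::unordered_map<std::string, int>;

// Finds a "<name>_<counter>" suffix that is free and claims it in one step.
// emplace() inserts only when the key is absent, so a failed insertion
// means the candidate conflicts with an existing symbol.
std::string insertWithUniqueName(SymbolMap &table, const std::string &name,
                                 int symbolId, unsigned &uniquingCounter) {
  auto conflicts = [&](const std::string &candidate) {
    return !table.emplace(candidate, symbolId).second;
  };
  std::string nameBuffer = name;
  const std::size_t originalLength = nameBuffer.size();
  do {
    nameBuffer.resize(originalLength);
    nameBuffer += '_';
    nameBuffer += std::to_string(uniquingCounter++);
  } while (conflicts(nameBuffer));
  return nameBuffer; // already recorded in `table` by the checker
}
```

The counter is passed by reference, so it can persist across calls. Later renames therefore do not retry suffixes that earlier calls already consumed.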
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index c54ed3a5f30f52b..fb45faaa712f7a9 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-opt %s -convert-gpu-to-nvvm -cse -canonicalize | FileCheck %s
 
 gpu.module @modules {
-  // CHECK: llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_2() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
-  llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
-  llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_0() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>  
-  llvm.mlir.global internal @__shmem_dynamic_shared_memory_kernel_1() {alignment = 16 : i64} : !llvm.array<0 x i8>  
+  // CHECK: llvm.mlir.global internal @__dynamic_shmem__3() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
+  llvm.mlir.global internal @__dynamic_shmem__0() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>
+  llvm.mlir.global internal @__dynamic_shmem__1() {addr_space = 3 : i32, alignment = 4 : i64} : !llvm.array<0 x i8>  
+  llvm.mlir.global internal @__dynamic_shmem__2() {alignment = 16 : i64} : !llvm.array<0 x i8>  
   // CHECK-LABEL: llvm.func @dynamic_shared_memory_kernel(
   // CHECK-SAME: %[[arg0:.+]]: i64)
   gpu.func @dynamic_shared_memory_kernel(%d : index) kernel attributes {gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 1, 1, 1>} {    
@@ -24,7 +24,7 @@ gpu.module @modules {
 // CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
 // CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
 // CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
 // CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
@@ -58,7 +58,7 @@ gpu.module @modules {
 // CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
 // CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
 // CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
 // CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
@@ -84,7 +84,7 @@ gpu.module @modules {
 // CHECK: %[[S1:.+]] = llvm.mlir.constant(64 : index) : i64
 // CHECK: %[[S2:.+]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[S3:.+]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__shmem_dynamic_shared_memory_kernel_2 : !llvm.ptr<3>
+// CHECK: %[[S4:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
 // CHECK: %[[S5:.+]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK: %[[S6:.+]] = llvm.insertvalue %[[S4]], %[[S5]][0] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> 
 // CHECK: %[[S7:.+]] = llvm.getelementptr %[[S4]][8192] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8

>From dbe76be3ea4c82d685bb7823ff972ca491b2e015 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 10 Nov 2023 15:05:50 +0100
Subject: [PATCH 12/12] use `getOps`

---
 .../lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 7be79e44e0f6945..00cd8010015da81 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -575,15 +575,14 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol(
   // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
   // LLVM::GlobalOp is suitable for shared memory, return it.
   llvm::StringSet<> existingGlobalNames;
-  for (auto &innerOp : moduleOp->getRegions().front().front().getOperations()) {
-    if (auto globalOp = dyn_cast<LLVM::GlobalOp>(innerOp)) {
-      existingGlobalNames.insert(globalOp.getSymName());
-      if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
-        if (globalOp.getAddrSpace() == addressSpace &&
-            arrayType.getNumElements() == 0 &&
-            globalOp.getAlignment().value_or(0) == alignmentByte) {
-          return globalOp;
-        }
+  for (auto globalOp :
+       moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
+    existingGlobalNames.insert(globalOp.getSymName());
+    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+      if (globalOp.getAddrSpace() == addressSpace &&
+          arrayType.getNumElements() == 0 &&
+          globalOp.getAlignment().value_or(0) == alignmentByte) {
+        return globalOp;
       }
     }
   }
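`Block::getOps<OpT>()` visits only the operations of type `OpT`, so the `dyn_cast` check that used to sit inside the loop body now lives in the loop header. The sketch below shows equivalent filtering with standard C++ RTTI. The `Op`, `GlobalOp`, and `FuncOp` types and `getOpsOfType` are hypothetical, and `dynamic_cast` stands in for LLVM-style `dyn_cast`. For simplicity the result is collected into a vector, whereas MLIR returns a lazily filtered range.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical operation hierarchy standing in for MLIR operations.
struct Op {
  virtual ~Op() = default;
};
struct GlobalOp : Op {
  explicit GlobalOp(std::string name) : symName(std::move(name)) {}
  std::string symName;
};
struct FuncOp : Op {};

// Returns only the ops whose dynamic type is OpT, in their original order.
template <typename OpT>
std::vector<OpT *> getOpsOfType(const std::vector<std::unique_ptr<Op>> &ops) {
  std::vector<OpT *> result;
  for (const auto &op : ops)
    if (auto *typed = dynamic_cast<OpT *>(op.get()))
      result.push_back(typed);
  return result;
}
```

Iterating over `getOpsOfType<GlobalOp>(ops)` gives the same flat loop body as the `getOps<LLVM::GlobalOp>()` version, with one less level of nesting.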


