[Mlir-commits] [mlir] [mlir][gpu] Introduce `gpu.dynamic.shared.memory` Op (PR #71516)

llvmlistbot at llvm.org
Tue Nov 7 03:06:48 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

While the `gpu.launch` Op allows setting the size via the `dynamic_shared_memory_size` argument, accessing the dynamic shared memory is very convoluted. This PR implements the proposed Op, `gpu.dynamic.shared.memory`, which aims to simplify the use of dynamic shared memory.

RFC: https://discourse.llvm.org/t/rfc-simplifying-dynamic-shared-memory-access-in-gpu/

**Proposal from RFC**
This PR introduces the `gpu.dynamic.shared.memory` Op to make the dynamic shared memory feature easier to use. Dynamic shared memory is a powerful feature that allows shared memory to be allocated at runtime, when the kernel is launched from the host. Afterwards, the memory can be accessed directly from the device. I believe a similar story exists for AMDGPU.

**New Op Features**

- No more 0-Sized Global Symbol Generation: The lowering hides the 1st and 3rd steps of the current approach (creating the 0-sized global symbol and accessing it).
- Simplified Shared Memory Access: No need for `reinterpret_cast` or `subview`. The offset argument is sufficient.
- Compile-time Bound Check: The Op verifier checks for `dynamic_shared_memory_size < offset` (an out-of-bounds access) when both are compile-time constants.
- Runtime Bound Check (TODO): We can add a `{dynamicBoundCheck}` attribute that checks for `dynamic_shared_memory_size < offset` at runtime. This is optional and definitely adds overhead, but it can be beneficial for debugging.
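
The compile-time bound check can be pictured with a small standalone sketch. This is not the PR's verifier: `BoundCheck` and `checkDynamicSharedMemoryBound` are hypothetical names, and the sketch assumes that both values are expressed in bytes and are passed as `std::nullopt` when they are not compile-time constants.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical result type for this sketch; not part of MLIR.
enum class BoundCheck { Skipped, InBounds, OutOfBounds };

// Mirrors the condition stated above: the access is flagged when
// dynamic_shared_memory_size < offset. If either value is unknown at
// compile time, the check is skipped rather than failed.
BoundCheck checkDynamicSharedMemoryBound(std::optional<int64_t> sizeBytes,
                                         std::optional<int64_t> offsetBytes) {
  if (!sizeBytes || !offsetBytes)
    return BoundCheck::Skipped;
  return *sizeBytes < *offsetBytes ? BoundCheck::OutOfBounds
                                   : BoundCheck::InBounds;
}
```

A stricter variant could also add the byte size of the returned memref to the offset before comparing; the condition above only follows the wording of the proposal.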

**Current Way of Using Dynamic Shared Memory with MLIR**

Let me illustrate the challenges of using dynamic shared memory in MLIR with an example below. The process involves several steps:
- `memref.global`: declare the 0-sized array that NVPTX expects
- `dynamic_shared_memory_size`: set the size of dynamic shared memory
- `memref.get_global`: access the global symbol
- `reinterpret_cast` and `subview`: many Ops for pointer arithmetic

```
// Step 1. Create 0-sized global symbol. Manually set the alignment
memref.global "private" @dynamicShmem  : memref<0xf16, 3> { alignment = 16 }
func.func @main() {
  // Step 2. Allocate shared memory
  gpu.launch blocks(...) threads(...)
    dynamic_shared_memory_size %c10000 {
    // Step 3. Access the global object
    %shmem = memref.get_global @dynamicShmem : memref<0xf16, 3>
    // Step 4. A sequence of `memref.reinterpret_cast` and `memref.subview` operations.
    %4 = memref.reinterpret_cast %shmem to offset: [0], sizes: [14, 64, 128],  strides: [8192,128,1] : memref<0xf16, 3> to memref<14x64x128xf16,3>
    %5 = memref.subview %4[7, 0, 0][7, 64, 128][1,1,1] : memref<14x64x128xf16,3> to memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3>
    %6 = memref.subview %5[2, 0, 0][1, 64, 128][1,1,1] : memref<7x64x128xf16, strided<[8192, 128, 1], offset: 57344>, 3> to memref<64x128xf16, strided<[128, 1], offset: 73728>, 3>
    %7 = memref.subview %6[0, 0][64, 64][1,1]  : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>
    %8 = memref.subview %6[32, 0][64, 64][1,1] : memref<64x128xf16, strided<[128, 1], offset: 73728>, 3> to memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>
    // Step 5. Use
    "test.use.shared.memory"(%7) : (memref<64x64xf16, strided<[128, 1], offset: 73728>, 3>) -> (index)
    "test.use.shared.memory"(%8) : (memref<64x64xf16, strided<[128, 1], offset: 77824>, 3>) -> (index)
    gpu.terminator
  }
}
```
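
The `offset:` values in the strided types above follow from one rule: each `memref.subview` adds the sum of `index * stride` to the offset of its source. A minimal standalone sketch of that arithmetic (the helper name `subviewOffset` is hypothetical and is not an MLIR API):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Element offset of a subview: the source offset plus sum(index[d] * stride[d]).
// Assumes `indices` and `strides` have the same length.
int64_t subviewOffset(int64_t baseOffset, const std::vector<int64_t> &indices,
                      const std::vector<int64_t> &strides) {
  int64_t offset = baseOffset;
  for (std::size_t d = 0; d < indices.size(); ++d)
    offset += indices[d] * strides[d];
  return offset;
}
```

Applied to the example: `%5` starts at `7 * 8192 = 57344`, `%6` at `57344 + 2 * 8192 = 73728`, `%7` stays at `73728`, and `%8` starts at `73728 + 32 * 128 = 77824`. All offsets are counted in `f16` elements, not bytes.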

With the new Op, the program above can be written as:

```
func.func @main() {
    gpu.launch blocks(...) threads(...) dynamic_shared_memory_size %c10000 {
        %i = arith.constant 18 : index
        // Step 1: Obtain shared memory directly
        %7 = gpu.dynamic.shared.memory [%i,0,0] : memref<64x64xf16, 3>
        %i2 = arith.addi %i, %c1
        %8 = gpu.dynamic.shared.memory [%i2,0,0] : memref<64x64xf16, 3>

        // Step 2: Utilize the shared memory
        "test.use.shared.memory"(%7) : (memref<64x64xf16, 3>) -> (index)
        "test.use.shared.memory"(%8) : (memref<64x64xf16, 3>) -> (index)
    }
}
```
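
To see how the offset list maps to a position in shared memory, the sketch below linearizes the offsets the way a `getelementptr` over a statically shaped array would: the first offset counts whole result-shaped tiles, and the remaining offsets index within a tile in row-major order. The function name is hypothetical, and the sketch assumes exactly `rank + 1` offsets and a static shape.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Linear element offset for offsets = [tile, i0, i1, ...] into tiles of the
// given static shape. Assumes offsets.size() == shape.size() + 1.
int64_t dynamicSharedMemoryElementOffset(const std::vector<int64_t> &shape,
                                         const std::vector<int64_t> &offsets) {
  assert(offsets.size() == shape.size() + 1 && "expected rank + 1 offsets");
  int64_t linear = offsets[0]; // number of whole tiles to skip
  for (std::size_t d = 0; d < shape.size(); ++d)
    linear = linear * shape[d] + offsets[d + 1];
  return linear;
}
```

For `memref<64x64xf16, 3>`, the offsets `[18, 0, 0]` give `18 * 64 * 64 = 73728` elements and `[19, 0, 0]` give `77824`, the same starting offsets as `%7` and `%8` in the `subview`-based version. Only the starting positions coincide: the `subview` results keep a row stride of 128, while the new results are contiguous 64x64 tiles.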

---

Patch is 30.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71516.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUBase.td (+10) 
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+13) 
- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUOps.td (+72-2) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h (+3) 
- (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+103) 
- (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h (+22) 
- (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+3) 
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+113-19) 
- (added) mlir/test/Dialect/GPU/dynamic-shared-memory.mlir (+64) 
- (modified) mlir/test/Dialect/GPU/invalid.mlir (+49) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 755c82d8b75c9c0..057b507c394e80f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -52,6 +52,16 @@ def GPU_Dialect : Dialect {
     /// Returns the numeric value used to identify the private memory address
     /// space.
     static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; }
+    
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the workgroup memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool hasWorkgroupMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute has an integer address
+    /// space that matches the workgroup memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value 'workgroup`.
+    static bool isWorkgroupMemoryAddressSpace(Attribute memorySpace);  
   }];
 
   let dependentDialects = ["arith::ArithDialect"];
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 14a1fac5fd255f3..286856324950eb7 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -17,6 +17,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/DLTI/Traits.h"
 #include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -32,6 +33,18 @@
 namespace mlir {
 namespace gpu {
 
+/// GPU memory space identifiers.
+enum GPUMemorySpace {
+  /// Generic memory space identifier.
+  kGenericMemorySpace = 0,
+
+  /// Global memory space identifier.
+  kGlobalMemorySpace = 1,
+
+  /// Shared memory space identifier.
+  kSharedMemorySpace = 3
+};
+
 /// Utility class for the GPU dialect to represent triples of `Value`s
 /// accessible through `.x`, `.y`, and `.z` similarly to CUDA notation.
 struct KernelDim3 {
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6375d35f4311295..eac5b0096a3e10c 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -433,6 +433,74 @@ def GPU_GPUFuncOp : GPU_Op<"func", [
   let hasVerifier = 1;
 }
 
+def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic.shared.memory", 
+                              [MemoryEffects<[MemWrite]>] > {
+  let summary = "Get the memref for dynamic shared memory";
+  
+  let description = [{
+    This operation returns shared memory, also referred to as workgroup memory,
+    using given offsets.
+
+    It is possible to use both constants and SSA values as offsets. 
+
+    If this operation is used within a `gpu.launch`, the verifier will make an
+    attempt to verify that the offsets fall within bounds by utilizing the 
+    `dynamic_shared_memory_size` argument of `gpu.launch` when the values are
+    compile-time constants. Otherwise, the verifier does not perform offset 
+    checks.
+    
+    Examples: 
+    ```mlir
+    // Constant value, offset = 32 * 64 * sizeof(f32) * 1
+    %0 = gpu.dynamic.shared.memory [1] : memref<32x64xf32, #gpu.address_space<workgroup>>
+
+    // Multi-dimensional constant values, offset = (32 * 64 * 1 + 8) * sizeof(f32)
+    %0 = gpu.dynamic.shared.memory [1, 0, 8] : memref<32x64xf32, #gpu.address_space<workgroup>>
+
+    // Multi-dimensional dynamic values, offset = (32 * 64 * %1) * sizeof(f32)
+    %0 = gpu.dynamic.shared.memory [%1, 0, 0] : (index) -> memref<32x32xf32>
+
+    // Multi-dimensional mixed values, offset = (32 * 64 * %1 + 8) * sizeof(f32)
+    %0 = gpu.dynamic.shared.memory [%1, 0, 8] : (index) -> memref<32x32xf32>
+    ```
+  }];  
+
+  let arguments = (ins 
+    Variadic<Index>:$dynamic_offsets,
+    DenseI64ArrayAttr:$static_offsets
+  );
+
+  let results = (outs Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$resultMemref);
+
+  let assemblyFormat = [{    
+    custom<DynamicIndexList>($dynamic_offsets, $static_offsets)
+    attr-dict 
+    `:` type($resultMemref)
+  }];
+
+   let builders = [
+    OpBuilder<(ins "Type":$memref, "int64_t":$offsets)>,
+    OpBuilder<(ins "Type":$memref, "OpFoldResult":$offsets)>,
+    OpBuilder<(ins "Type":$memref, "ArrayRef<int64_t>":$offsets)>,
+    OpBuilder<(ins "Type":$memref, "ArrayRef<OpFoldResult>":$offsets)>,
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return a vector with all the static and dynamic offsets indices.
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      OpBuilder builder(getContext());
+      return getMixedValues(getStaticOffsets(), getDynamicOffsets(), builder);
+    }
+
+    bool hasDynamicOffsets() {
+      auto dynPos = getDynamicOffsets();
+      return std::any_of(dynPos.begin(), dynPos.end(),
+                         [](Value operand) { return operand != nullptr; });
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def LaunchIndx : AnyTypeOf<[Index, I32, I64]>;
 
 def GPU_LaunchFuncOp :GPU_Op<"launch_func", [
@@ -587,7 +655,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
     Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
                Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
-               Optional<I32>:$dynamicSharedMemorySize)>,
+               Optional<I32>:$dynamicSharedMemorySize,
+               OptionalAttr<I32Attr>:$guray)>,
     Results<(outs Optional<GPU_AsyncToken>:$asyncToken)> {
   let summary = "GPU kernel launch operation";
 
@@ -693,7 +762,8 @@ def GPU_LaunchOp : GPU_Op<"launch", [
       CArg<"Type", "nullptr">:$asyncTokenType,
       CArg<"ValueRange", "{}">:$asyncDependencies,
       CArg<"TypeRange", "{}">:$workgroupAttributions,
-      CArg<"TypeRange", "{}">:$privateAttributions)>
+      CArg<"TypeRange", "{}">:$privateAttributions,
+      CArg<"IntegerAttr", "IntegerAttr()">:$guray)>
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 8ff8f850a9c1858..08019e77ae6af8a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -27,6 +27,9 @@
 namespace mlir {
 namespace NVVM {
 
+// Shared memory has 128-bit alignment
+constexpr int kSharedMemoryAlignmentBit = 128;
+
 /// NVVM memory space identifiers.
 enum NVVMMemorySpace {
   /// Global memory space identifier.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6d2585aa30ab4c5..c8f809ee88c54d7 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -9,6 +9,7 @@
 #include "GPUOpsLowering.h"
 
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -554,6 +555,108 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
   return IntegerAttr::get(IntegerType::get(ctx, 64), space);
 }
 
+/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
+/// or uses existing symbol.
+LLVM::GlobalOp getDynamicSharedMemorySymbol(
+    ConversionPatternRewriter &rewriter, gpu::DynamicSharedMemoryOp op,
+    const LLVMTypeConverter *typeConverter, unsigned alignmentBit) {
+  std::optional<LLVM::GlobalOp> existingGlobalOp;
+
+  MemRefType memrefType = op.getResultMemref().getType();
+  assert(memrefType && memrefType.hasStaticShape() &&
+         "expected static shaped memref type");
+
+  LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
+  assert(funcOp && "cannot find llvm.func op");
+
+  gpu::GPUModuleOp moduleOp = funcOp->getParentOfType<gpu::GPUModuleOp>();
+  assert(moduleOp && "cannot find gpu.module op");
+
+  // Use already generated global op if it exists
+  int index = 0;
+  std::string prefix = llvm::formatv("__shmem_{0}", funcOp.getSymName());
+  moduleOp->walk([&](LLVM::GlobalOp globalOp) {
+    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
+      if (arrayType.getNumElements() == 0) {
+        existingGlobalOp = globalOp;
+        return WalkResult::interrupt();
+      }
+    }
+    if (globalOp.getSymName().startswith(prefix))
+      index++;
+    return WalkResult::advance();
+  });
+  if (existingGlobalOp.has_value())
+    return existingGlobalOp.value();
+
+  // Generate a new global op
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(&moduleOp.front());
+
+  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
+      typeConverter->convertType(memrefType.getElementType()), 0);
+  std::string name = std::string(llvm::formatv("{0}_{1}", prefix, index));
+  // TODO: better alignment calculation
+  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
+  return rewriter.create<LLVM::GlobalOp>(
+      funcOp->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
+      LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignmentByte,
+      mlir::gpu::GPUMemorySpace::kSharedMemorySpace);
+}
+
+LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
+    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  MemRefType memrefType = op.getResultMemref().getType();
+  auto elementType = typeConverter->convertType(memrefType.getElementType());
+  assert(memrefType && "memref is not valid");
+
+  // Step 1: Generate a global symbol or existing for the dynamic shared
+  // memory
+  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
+      rewriter, op, getTypeConverter(), alignmentBit);
+  assert(shmemOp && "cannot find module op or failed generating global op");
+
+  // Step 2. Get address of the global symbol
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
+  auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp);
+  Type baseType = basePtr->getResultTypes().front();
+
+  // Step 3. Fill mixed dynamic and static offsets
+  SmallVector<LLVM::GEPArg> gepArgs;
+  for (auto [idx, value] : llvm::enumerate(op.getStaticOffsets())) {
+    if (ShapedType::isDynamic(value))
+      gepArgs.push_back(LLVM::GEPArg(adaptor.getDynamicOffsets()[idx]));
+    else
+      gepArgs.push_back(LLVM::GEPArg(value));
+  }
+
+  // Step 4. Generate GEP using offsets
+  Type gepResultType = elementType;
+  if (memrefType.hasStaticShape()) {
+    for (int64_t numElem : llvm::reverse(memrefType.getShape())) {
+      gepResultType = LLVM::LLVMArrayType::get(gepResultType, numElem);
+    }
+  }
+  Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, gepResultType,
+                                                basePtr, gepArgs);
+
+  // Step 5. Create a memref descriptor
+  SmallVector<Value> shape, strides;
+  Value sizeBytes;
+  getMemRefDescriptorSizes(loc, memrefType, {}, rewriter, shape, strides,
+                           sizeBytes);
+
+  auto memRefDescriptor = this->createMemRefDescriptor(
+      loc, memrefType, shmemPtr, shmemPtr, shape, strides, rewriter);
+
+  // Step 6. Replace the op with memref descriptor
+  rewriter.replaceOp(op, {memRefDescriptor});
+  return success();
+}
+
 void mlir::populateGpuMemorySpaceAttributeConversions(
     TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
   typeConverter.addTypeAttributeConversion(
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index bd90286494d8035..1805e2b06f40481 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -14,6 +14,28 @@
 
 namespace mlir {
 
+/// Lowering for gpu.dynamic.shared.memory to LLVM dialect. The pattern first
+/// create a 0-sized global array symbol similar as LLVM expects. Subsequently,
+/// it computes the offset using 'getelementptr' with its offset arguments.
+/// Finally, it constructs a memref descriptor with these values and return it.
+struct GPUDynamicSharedMemoryOpLowering
+    : public ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::DynamicSharedMemoryOp>::ConvertOpToLLVMPattern;
+  GPUDynamicSharedMemoryOpLowering(const LLVMTypeConverter &converter,
+                                   unsigned alignmentBit = 0)
+      : ConvertOpToLLVMPattern<gpu::DynamicSharedMemoryOp>(converter),
+        alignmentBit(alignmentBit) {}
+
+  LogicalResult
+  matchAndRewrite(gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+
+private:
+  // Alignment bit
+  unsigned alignmentBit;
+};
+
 struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
   GPUFuncOpLowering(const LLVMTypeConverter &converter,
                     unsigned allocaAddrSpace, unsigned workgroupAddrSpace,
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 935e3d2a4095003..86a77f557cb9579 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -325,6 +325,9 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
            GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
           converter);
 
+  patterns.add<GPUDynamicSharedMemoryOpLowering>(
+      converter, NVVM::kSharedMemoryAlignmentBit);
+
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
   // memory space and does not support `alloca`s with addrspace(5).
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 5eb2cadc884e151..cc86ad74ea3ca7d 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -164,17 +164,20 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
 // GPUDialect
 //===----------------------------------------------------------------------===//
 
-/// GPU memory space identifiers.
-enum GPUMemorySpace {
-  /// Generic memory space identifier.
-  kGenericMemorySpace = 0,
-
-  /// Global memory space identifier.
-  kGlobalMemorySpace = 1,
+bool GPUDialect::isWorkgroupMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == GPUMemorySpace::kSharedMemorySpace;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == getWorkgroupAddressSpace();
+  return false;
+}
 
-  /// Shared memory space identifier.
-  kSharedMemorySpace = 3
-};
+bool GPUDialect::hasWorkgroupMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isWorkgroupMemoryAddressSpace(memorySpace);
+}
 
 bool GPUDialect::isKernel(Operation *op) {
   UnitAttr isKernelAttr = op->getAttrOfType<UnitAttr>(getKernelFuncAttrName());
@@ -612,13 +615,16 @@ void gpu::addAsyncDependency(Operation *op, Value token) {
 // LaunchOp
 //===----------------------------------------------------------------------===//
 
+static constexpr int64_t kDynamic = std::numeric_limits<int32_t>::min();
+
 void LaunchOp::build(OpBuilder &builder, OperationState &result,
                      Value gridSizeX, Value gridSizeY, Value gridSizeZ,
                      Value getBlockSizeX, Value getBlockSizeY,
                      Value getBlockSizeZ, Value dynamicSharedMemorySize,
                      Type asyncTokenType, ValueRange asyncDependencies,
                      TypeRange workgroupAttributions,
-                     TypeRange privateAttributions) {
+                     TypeRange privateAttributions,
+                     IntegerAttr dynamicSharedMemorySizeAttr) {
   // Add a WorkGroup attribution attribute. This attribute is required to
   // identify private attributions in the list of block argguments.
   result.addAttribute(getNumWorkgroupAttributionsAttrName(),
@@ -634,7 +640,10 @@ void LaunchOp::build(OpBuilder &builder, OperationState &result,
                       getBlockSizeY, getBlockSizeZ});
   if (dynamicSharedMemorySize)
     result.addOperands(dynamicSharedMemorySize);
-
+  if (!dynamicSharedMemorySizeAttr) 
+    dynamicSharedMemorySizeAttr = builder.getI32IntegerAttr(kDynamic);
+  
+  result.addAttribute("guray", dynamicSharedMemorySizeAttr);
   // Create a kernel body region with kNumConfigRegionAttributes + N memory
   // attributions, where the first kNumConfigRegionAttributes arguments have
   // `index` type and the rest have the same types as the data operands.
@@ -759,6 +768,11 @@ void LaunchOp::print(OpAsmPrinter &p) {
   if (getDynamicSharedMemorySize())
     p << ' ' << getDynamicSharedMemorySizeKeyword() << ' '
       << getDynamicSharedMemorySize();
+  else if(getGurayAttr()) {
+    p << ' ' << getDynamicSharedMemorySizeKeyword() << ' ' << getGurayAttr().getInt();
+    
+  }
+  
 
   printAttributions(p, getWorkgroupKeyword(), getWorkgroupAttributions());
   printAttributions(p, getPrivateKeyword(), getPrivateAttributions());
@@ -768,7 +782,8 @@ void LaunchOp::print(OpAsmPrinter &p) {
   p.printRegion(getBody(), /*printEntryBlockArgs=*/false);
   p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
                               LaunchOp::getOperandSegmentSizeAttr(),
-                              getNumWorkgroupAttributionsAttrName()});
+                              getNumWorkgroupAttributionsAttrName(),
+                              "guray"});
 }
 
 // Parse the size assignment blocks for blocks and threads.  These have the form
@@ -854,12 +869,19 @@ ParseResult LaunchOp::parse(OpAsmParser &parser, OperationState &result) {
   bool hasDynamicSharedMemorySize = false;
   if (!parser.parseOptionalKeyword(
           LaunchOp::getDynamicSharedMemorySizeKeyword())) {
-    hasDynamicSharedMemorySize = true;
-    if (parser.parseOperand(dynamicSharedMemorySize) ||
-        parser.resolveOperand(dynamicSharedMemorySize,
-                              parser.getBuilder().getI32Type(),
-                              result.operands))
-      return failure();
+    IntegerAttr shmemAttr;
+    OptionalParseResult shmemAttrResult =
+        parser.parseOptionalAttribute(shmemAttr, parser.getBuilder().getI32Type());
+    if(!shmemAttrResult.has_value()) {
+      hasDynamicSharedMemorySize = true;
+      shmemAttr = parser.getBuilder().getI32IntegerAttr(kDynamic);
+      if (parser.parseOperand(dynamicSharedMemorySize) ||
+          parser.resolveOperand(dynamicSharedMemorySize,
+                                parser.getBuilder().getI32Type(),
+                                result.operands))
+        return failure();
+    }    
+    result.addAttribute("guray", shmemAttr);
   }
 
   // Create the region arguments, it has kNumConfigRegionAttributes arguments
@@ -2024,6 +2046,78 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DynamicSharedMemoryOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult gpu::DynamicSharedMemoryOp::verify() {
+  MemRefType memrefType = getResultMemref().getType();
+  unsigned long rank = memrefType.getRank();
+  unsigned long offset = getStaticOffsets().size();
+
+  // Number of offset can be one dimension larger the memref rank
+  if ((offset + 1) < rank) {
+    return emitOpError("Number of offset must match the rank of the memref");
+  }
+
+  // Check address space
+  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
+    return emitOpError() << "Address space must be "
+                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
+                         << stringifyEnum(gpu::AddressSpace::Workgroup)
+                         << "> or " << int(GPUMemorySpace::kSharedMemorySpace)
+                         << ".";
+  }
+
+  if (memrefType.hasStaticShape()) {
+  ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/71516

