[Mlir-commits] [mlir] [mlir][gpu] Allow gpu.dynamic_shared_memory return llvm.ptr (PR #96783)

Guray Ozen llvmlistbot at llvm.org
Fri Jul 12 09:39:10 PDT 2024


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/96783

From 9f305e9607bc8d741d04691dba6904ef2c2af5f5 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 26 Jun 2024 17:28:57 +0200
Subject: [PATCH 1/2] [mlir][gpu] Allow gpu.dynamic_shared_memory return
 llvm.ptr

The `gpu.dynamic_shared_memory` op is a convenient way to obtain the base pointer of dynamic shared memory. However, it currently only returns a `memref`.

This PR extends the op so that it can also return an `llvm.ptr`.
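For illustration, a minimal sketch showing both result forms side by side (the module, function, and value names below are placeholders, not taken from the patch):

```mlir
gpu.module @kernels {
  func.func @dynamic_shmem_both_forms() {
    // Existing form: a dynamically sized i8 memref in workgroup memory.
    %0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
    // New form: the same base address as a plain LLVM pointer in
    // address space 3 (shared/workgroup memory on NVVM and ROCDL targets).
    %1 = gpu.dynamic_shared_memory : !llvm.ptr<3>
    func.return
  }
}
```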
---
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td    | 11 ++-
 .../Conversion/GPUCommon/GPUOpsLowering.cpp   | 89 +++++++++++--------
 mlir/lib/Dialect/GPU/IR/GPUDialect.cpp        | 31 ++++---
 .../Dialect/GPU/dynamic-shared-memory.mlir    |  9 ++
 mlir/test/Dialect/GPU/invalid.mlir            |  2 +-
 5 files changed, 89 insertions(+), 53 deletions(-)

diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index c57d291552e60..7c613e3231d20 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -586,7 +586,7 @@ def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [Pure]>
     conveniently utilize `the dynamic_shared_memory_size` parameter of
     `gpu.launch` for this purpose.
 
-    Examples:
+    Example with memref:
     ```mlir
     %0 = gpu.dynamic.shared.memory : memref<?xi8, #gpu.address_space<workgroup>>
     %1 = memref.view %0[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>>
@@ -594,10 +594,15 @@ def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [Pure]>
     %2 = memref.view %0[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>>
                             to memref<32x64xf32, #gpu.address_space<workgroup>>
     ```
+
+    Example with llvm.ptr:
+    ```mlir
+    %0 = gpu.dynamic_shared_memory : !llvm.ptr<3>
+    ```
   }];
   let arguments = (ins);
-  let results = (outs Arg<MemRefRankOf<[I8], [1]>>:$resultMemref);
-  let assemblyFormat = [{ attr-dict `:` type($resultMemref) }];
+  let results = (outs AnyType:$result);
+  let assemblyFormat = [{ attr-dict `:` type($result) }];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6053e34f30a41..08a926fd5caac 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -559,21 +560,11 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
 
 /// Generates a symbol with 0-sized array type for dynamic shared memory usage,
 /// or uses existing symbol.
-LLVM::GlobalOp
-getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
-                             Operation *moduleOp, gpu::DynamicSharedMemoryOp op,
-                             const LLVMTypeConverter *typeConverter,
-                             MemRefType memrefType, unsigned alignmentBit) {
-  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
-
-  FailureOr<unsigned> addressSpace =
-      typeConverter->getMemRefAddressSpace(memrefType);
-  if (failed(addressSpace)) {
-    op->emitError() << "conversion of memref memory space "
-                    << memrefType.getMemorySpace()
-                    << " to integer address space "
-                       "failed. Consider adding memory space conversions.";
-  }
+LLVM::GlobalOp getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
+                                            Location loc, Operation *moduleOp,
+                                            unsigned addressSpace,
+                                            uint64_t alignmentByte,
+                                            Type elemType) {
 
   // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
   // LLVM::GlobalOp is suitable for shared memory, return it.
@@ -582,7 +573,7 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
        moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
     existingGlobalNames.insert(globalOp.getSymName());
     if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
-      if (globalOp.getAddrSpace() == addressSpace.value() &&
+      if (globalOp.getAddrSpace() == addressSpace &&
           arrayType.getNumElements() == 0 &&
           globalOp.getAlignment().value_or(0) == alignmentByte) {
         return globalOp;
@@ -603,34 +594,54 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(&moduleOp->getRegion(0).front().front());
 
-  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
-      typeConverter->convertType(memrefType.getElementType()), 0);
+  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(elemType, 0);
 
   return rewriter.create<LLVM::GlobalOp>(
-      op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
-      LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
-      addressSpace.value());
+      loc, zeroSizedArrayType, /*isConstant=*/false, LLVM::Linkage::Internal,
+      symName, /*value=*/Attribute(), alignmentByte, addressSpace);
 }
 
 LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
     gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  MemRefType memrefType = op.getResultMemref().getType();
-  Type elementType = typeConverter->convertType(memrefType.getElementType());
 
-  // Step 1: Generate a memref<0xi8> type
-  MemRefLayoutAttrInterface layout = {};
-  auto memrefType0sz =
-      MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+  unsigned addressSpace;
+  Type elementType;
+  uint64_t alignmentByte;
+  MemRefType memrefType0sz;
+
+  // Step 1. Determine the element type, alignment, and address space
+  if (MemRefType memrefType =
+          llvm::dyn_cast<MemRefType>(op.getResult().getType())) {
+    elementType = typeConverter->convertType(memrefType.getElementType());
+    MemRefLayoutAttrInterface layout = {};
+    memrefType0sz =
+        MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+
+    alignmentByte = alignmentBit / memrefType0sz.getElementTypeBitWidth();
+    FailureOr<unsigned> maybeAddressSpace =
+        getTypeConverter()->getMemRefAddressSpace(memrefType0sz);
+    if (failed(maybeAddressSpace)) {
+      op->emitError() << "conversion of memref memory space "
+                      << memrefType0sz.getMemorySpace()
+                      << " to integer address space "
+                         "failed. Consider adding memory space conversions.";
+    }
+    addressSpace = maybeAddressSpace.value();
+  } else {
+    auto ptr = dyn_cast<LLVM::LLVMPointerType>(op.getResult().getType());
+    addressSpace = ptr.getAddressSpace();
+    elementType = IntegerType::get(op->getContext(), 8);
+    alignmentByte = alignmentBit / elementType.getIntOrFloatBitWidth();
+  }
 
   // Step 2: Generate a global symbol or existing for the dynamic shared
   // memory with memref<0xi8> type
   LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
-  LLVM::GlobalOp shmemOp = {};
   Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
-  shmemOp = getDynamicSharedMemorySymbol(
-      rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
+  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
+      rewriter, loc, moduleOp, addressSpace, alignmentByte, elementType);
 
   // Step 3. Get address of the global symbol
   OpBuilder::InsertionGuard guard(rewriter);
@@ -643,15 +654,17 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
   Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
                                                 basePtr, gepArgs);
   // Step 5. Create a memref descriptor
-  SmallVector<Value> shape, strides;
-  Value sizeBytes;
-  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
-                           sizeBytes);
-  auto memRefDescriptor = this->createMemRefDescriptor(
-      loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
-
+  Value result = shmemPtr;
+  if (llvm::isa<MemRefType>(op.getResult().getType())) {
+    SmallVector<Value> shape, strides;
+    Value sizeBytes;
+    getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
+                             sizeBytes);
+    result = this->createMemRefDescriptor(loc, memrefType0sz, shmemPtr,
+                                          shmemPtr, shape, strides, rewriter);
+  }
   // Step 5. Replace the op with memref descriptor
-  rewriter.replaceOp(op, {memRefDescriptor});
+  rewriter.replaceOp(op, {result});
   return success();
 }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 3abaa3b3a81dd..28398c9468d2b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -2219,19 +2220,27 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 LogicalResult gpu::DynamicSharedMemoryOp::verify() {
   if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
     return emitOpError() << "must be inside an op with symbol table";
-
-  MemRefType memrefType = getResultMemref().getType();
-  // Check address space
-  if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
-    return emitOpError() << "address space must be "
-                         << gpu::AddressSpaceAttr::getMnemonic() << "<"
-                         << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
+  if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(getResult().getType())) {
+    return success();
   }
-  if (memrefType.hasStaticShape()) {
-    return emitOpError() << "result memref type must be memref<?xi8, "
-                            "#gpu.address_space<workgroup>>";
+  if (MemRefType memrefType =
+          llvm::dyn_cast<MemRefType>(getResult().getType())) {
+    // Check address space
+    if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
+      return emitOpError() << "address space must be "
+                           << gpu::AddressSpaceAttr::getMnemonic() << "<"
+                           << stringifyEnum(gpu::AddressSpace::Workgroup)
+                           << ">";
+    }
+    if (memrefType.hasStaticShape() ||
+        !memrefType.getElementType().isInteger(8)) {
+      return emitOpError() << "result memref type must be memref<?xi8, "
+                              "#gpu.address_space<workgroup>>";
+    }
+    return success();
   }
-  return success();
+  return emitOpError() << "result type must be either llvm.ptr or memref<?xi8, "
+                          "#gpu.address_space<workgroup>>";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index d73125fd763e6..75d7fb88e4dd8 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -99,4 +99,13 @@ gpu.module @modules {
 
     func.return
   }
+
+// CHECK-LABEL: llvm.func @func_device_function_plain_pointer
+func.func @func_device_function_plain_pointer()  {
+  // CHECK-DAG: %[[S5:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
+  //     CHECK: "test.use.shared.memory"(%[[S5]]) : (!llvm.ptr<3>) -> ()
+  %shmem = gpu.dynamic_shared_memory : !llvm.ptr<3>
+  "test.use.shared.memory"(%shmem) : (!llvm.ptr<3>) -> ()
+  func.return
+}
 }
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index e9d8f329be8ed..ecebea59e1964 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -818,7 +818,7 @@ func.func @main(%arg0 : index) {
              threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
              dynamic_shared_memory_size %shmemSize
   {
-    // expected-error @below {{'gpu.dynamic_shared_memory' op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, #gpu.address_space<workgroup>}}
+    // expected-error @below {{'gpu.dynamic_shared_memory' op result memref type must be memref<?xi8, #gpu.address_space<workgroup>>}}
     %0 = gpu.dynamic_shared_memory : memref<?xf32, #gpu.address_space<workgroup>>
     gpu.terminator
   }

From f0cec0c02f232a864ba4dda473cfcd55d81d0e5e Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 12 Jul 2024 18:39:01 +0200
Subject: [PATCH 2/2] Update GPUOpsLowering.cpp

---
 mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 08a926fd5caac..05290215b1498 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -630,7 +630,7 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
     }
     addressSpace = maybeAddressSpace.value();
   } else {
-    auto ptr = dyn_cast<LLVM::LLVMPointerType>(op.getResult().getType());
+    auto ptr = cast<LLVM::LLVMPointerType>(op.getResult().getType());
     addressSpace = ptr.getAddressSpace();
     elementType = IntegerType::get(op->getContext(), 8);
     alignmentByte = alignmentBit / elementType.getIntOrFloatBitWidth();


