[Mlir-commits] [mlir] [mlir][gpu] Allow gpu.dynamic_shared_memory return llvm.ptr (PR #96783)

Matthias Springer llvmlistbot at llvm.org
Fri Jul 12 00:54:32 PDT 2024


================
@@ -603,34 +594,54 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(&moduleOp->getRegion(0).front().front());
 
-  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
-      typeConverter->convertType(memrefType.getElementType()), 0);
+  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(elemType, 0);
 
   return rewriter.create<LLVM::GlobalOp>(
-      op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
-      LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
-      addressSpace.value());
+      loc, zeroSizedArrayType, /*isConstant=*/false, LLVM::Linkage::Internal,
+      symName, /*value=*/Attribute(), alignmentByte, addressSpace);
 }
 
 LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
     gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  MemRefType memrefType = op.getResultMemref().getType();
-  Type elementType = typeConverter->convertType(memrefType.getElementType());
 
-  // Step 1: Generate a memref<0xi8> type
-  MemRefLayoutAttrInterface layout = {};
-  auto memrefType0sz =
-      MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+  unsigned addressSpace;
+  Type elementType;
+  uint64_t alignmentByte;
+  MemRefType memrefType0sz;
+
+  // Step 1. Find out the element type, alignment and address space
+  if (MemRefType memrefType =
+          llvm::dyn_cast<MemRefType>(op.getResult().getType())) {
+    elementType = typeConverter->convertType(memrefType.getElementType());
+    MemRefLayoutAttrInterface layout = {};
+    memrefType0sz =
+        MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+
+    alignmentByte = alignmentBit / memrefType0sz.getElementTypeBitWidth();
+    FailureOr<unsigned> maybeAddressSpace =
+        getTypeConverter()->getMemRefAddressSpace(memrefType0sz);
+    if (failed(maybeAddressSpace)) {
+      op->emitError() << "conversion of memref memory space "
+                      << memrefType0sz.getMemorySpace()
+                      << " to integer address space "
+                         "failed. Consider adding memory space conversions.";
+    }
+    addressSpace = maybeAddressSpace.value();
+  } else {
+    auto ptr = dyn_cast<LLVM::LLVMPointerType>(op.getResult().getType());
----------------
matthias-springer wrote:

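Use `cast<LLVM::LLVMPointerType>` here instead of `dyn_cast`. In this `else` branch the result type is not a `MemRefType`, so it must be an `llvm.ptr`. `cast` asserts that invariant, whereas `dyn_cast` would silently yield a null type that is then dereferenced.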

https://github.com/llvm/llvm-project/pull/96783


More information about the Mlir-commits mailing list