[Mlir-commits] [mlir] [mlir][gpu] Allow gpu.dynamic_shared_memory return llvm.ptr (PR #96783)
Matthias Springer
llvmlistbot at llvm.org
Fri Jul 12 00:54:32 PDT 2024
================
@@ -603,34 +594,54 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&moduleOp->getRegion(0).front().front());
- auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
- typeConverter->convertType(memrefType.getElementType()), 0);
+ auto zeroSizedArrayType = LLVM::LLVMArrayType::get(elemType, 0);
return rewriter.create<LLVM::GlobalOp>(
- op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
- LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
- addressSpace.value());
+ loc, zeroSizedArrayType, /*isConstant=*/false, LLVM::Linkage::Internal,
+ symName, /*value=*/Attribute(), alignmentByte, addressSpace);
}
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- MemRefType memrefType = op.getResultMemref().getType();
- Type elementType = typeConverter->convertType(memrefType.getElementType());
- // Step 1: Generate a memref<0xi8> type
- MemRefLayoutAttrInterface layout = {};
- auto memrefType0sz =
- MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+ unsigned addressSpace;
+ Type elementType;
+ uint64_t alignmentByte;
+ MemRefType memrefType0sz;
+
+ // Step 1. Find out the element type, alignment and address space
+ if (MemRefType memrefType =
+ llvm::dyn_cast<MemRefType>(op.getResult().getType())) {
+ elementType = typeConverter->convertType(memrefType.getElementType());
+ MemRefLayoutAttrInterface layout = {};
+ memrefType0sz =
+ MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+
+ alignmentByte = alignmentBit / memrefType0sz.getElementTypeBitWidth();
+ FailureOr<unsigned> maybeAddressSpace =
+ getTypeConverter()->getMemRefAddressSpace(memrefType0sz);
+ if (failed(maybeAddressSpace)) {
+ op->emitError() << "conversion of memref memory space "
+ << memrefType0sz.getMemorySpace()
+ << " to integer address space "
+ "failed. Consider adding memory space conversions.";
+ }
+ addressSpace = maybeAddressSpace.value();
+ } else {
+ auto ptr = dyn_cast<LLVM::LLVMPointerType>(op.getResult().getType());
----------------
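matthias-springer wrote:
Use `cast<LLVM::LLVMPointerType>` here instead of `dyn_cast`. In this `else` branch the result type is not a `MemRefType`, so it is expected to be an `LLVM::LLVMPointerType`. `cast` asserts that invariant, whereas an unchecked `dyn_cast` would silently yield a null type on a mismatch.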
https://github.com/llvm/llvm-project/pull/96783
More information about the Mlir-commits mailing list