[Mlir-commits] [mlir] [mlir][gpu] Allow gpu.dynamic_shared_memory return llvm.ptr (PR #96783)
Guray Ozen
llvmlistbot at llvm.org
Fri Jul 12 09:39:10 PDT 2024
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/96783
From 9f305e9607bc8d741d04691dba6904ef2c2af5f5 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 26 Jun 2024 17:28:57 +0200
Subject: [PATCH 1/2] [mlir][gpu] Allow gpu.dynamic_shared_memory return
llvm.ptr
The `gpu.dynamic_shared_memory` op is a convenient way to obtain the dynamic shared memory pointer; however, it currently only supports a memref result.
This PR extends the op so that it can also return `llvm.ptr`.
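For illustration, a minimal sketch of the two result forms (taken from the op documentation and the tests updated below):

```mlir
// Existing form: a dynamically sized i8 memref in workgroup memory.
%0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>

// New form added by this PR: a bare LLVM pointer in the workgroup address space.
%1 = gpu.dynamic_shared_memory : !llvm.ptr<3>
```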
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 11 ++-
.../Conversion/GPUCommon/GPUOpsLowering.cpp | 89 +++++++++++--------
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 31 ++++---
.../Dialect/GPU/dynamic-shared-memory.mlir | 9 ++
mlir/test/Dialect/GPU/invalid.mlir | 2 +-
5 files changed, 89 insertions(+), 53 deletions(-)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index c57d291552e60..7c613e3231d20 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -586,7 +586,7 @@ def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [Pure]>
conveniently utilize the `dynamic_shared_memory_size` parameter of
`gpu.launch` for this purpose.
- Examples:
+ Example with memref:
```mlir
%0 = gpu.dynamic.shared.memory : memref<?xi8, #gpu.address_space<workgroup>>
%1 = memref.view %0[%c8192][] : memref<?xi8, #gpu.address_space<workgroup>>
@@ -594,10 +594,15 @@ def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [Pure]>
%2 = memref.view %0[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>>
to memref<32x64xf32, #gpu.address_space<workgroup>>
```
+
+ Example with llvm.ptr:
+ ```mlir
+ %0 = gpu.dynamic_shared_memory : !llvm.ptr<3>
+ ```
}];
let arguments = (ins);
- let results = (outs Arg<MemRefRankOf<[I8], [1]>>:$resultMemref);
- let assemblyFormat = [{ attr-dict `:` type($resultMemref) }];
+ let results = (outs AnyType:$result);
+ let assemblyFormat = [{ attr-dict `:` type($result) }];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 6053e34f30a41..08a926fd5caac 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"
@@ -559,21 +560,11 @@ static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
/// or uses existing symbol.
-LLVM::GlobalOp
-getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
- Operation *moduleOp, gpu::DynamicSharedMemoryOp op,
- const LLVMTypeConverter *typeConverter,
- MemRefType memrefType, unsigned alignmentBit) {
- uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();
-
- FailureOr<unsigned> addressSpace =
- typeConverter->getMemRefAddressSpace(memrefType);
- if (failed(addressSpace)) {
- op->emitError() << "conversion of memref memory space "
- << memrefType.getMemorySpace()
- << " to integer address space "
- "failed. Consider adding memory space conversions.";
- }
+LLVM::GlobalOp getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
+ Location loc, Operation *moduleOp,
+ unsigned addressSpace,
+ uint64_t alignmentByte,
+ Type elemType) {
// Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
// LLVM::GlobalOp is suitable for shared memory, return it.
@@ -582,7 +573,7 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
moduleOp->getRegion(0).front().getOps<LLVM::GlobalOp>()) {
existingGlobalNames.insert(globalOp.getSymName());
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
- if (globalOp.getAddrSpace() == addressSpace.value() &&
+ if (globalOp.getAddrSpace() == addressSpace &&
arrayType.getNumElements() == 0 &&
globalOp.getAlignment().value_or(0) == alignmentByte) {
return globalOp;
@@ -603,34 +594,54 @@ getDynamicSharedMemorySymbol(ConversionPatternRewriter &rewriter,
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&moduleOp->getRegion(0).front().front());
- auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
- typeConverter->convertType(memrefType.getElementType()), 0);
+ auto zeroSizedArrayType = LLVM::LLVMArrayType::get(elemType, 0);
return rewriter.create<LLVM::GlobalOp>(
- op->getLoc(), zeroSizedArrayType, /*isConstant=*/false,
- LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte,
- addressSpace.value());
+ loc, zeroSizedArrayType, /*isConstant=*/false, LLVM::Linkage::Internal,
+ symName, /*value=*/Attribute(), alignmentByte, addressSpace);
}
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = op.getLoc();
- MemRefType memrefType = op.getResultMemref().getType();
- Type elementType = typeConverter->convertType(memrefType.getElementType());
- // Step 1: Generate a memref<0xi8> type
- MemRefLayoutAttrInterface layout = {};
- auto memrefType0sz =
- MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+ unsigned addressSpace;
+ Type elementType;
+ uint64_t alignmentByte;
+ MemRefType memrefType0sz;
+
+ // Step 1. Find out the element type, alignment and address space
+ if (MemRefType memrefType =
+ llvm::dyn_cast<MemRefType>(op.getResult().getType())) {
+ elementType = typeConverter->convertType(memrefType.getElementType());
+ MemRefLayoutAttrInterface layout = {};
+ memrefType0sz =
+ MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());
+
+ alignmentByte = alignmentBit / memrefType0sz.getElementTypeBitWidth();
+ FailureOr<unsigned> maybeAddressSpace =
+ getTypeConverter()->getMemRefAddressSpace(memrefType0sz);
+ if (failed(maybeAddressSpace)) {
+ op->emitError() << "conversion of memref memory space "
+ << memrefType0sz.getMemorySpace()
+ << " to integer address space "
+ "failed. Consider adding memory space conversions.";
+ }
+ addressSpace = maybeAddressSpace.value();
+ } else {
+ auto ptr = dyn_cast<LLVM::LLVMPointerType>(op.getResult().getType());
+ addressSpace = ptr.getAddressSpace();
+ elementType = IntegerType::get(op->getContext(), 8);
+ alignmentByte = alignmentBit / elementType.getIntOrFloatBitWidth();
+ }
// Step 2: Generate a global symbol or existing for the dynamic shared
// memory with memref<0xi8> type
LLVM::LLVMFuncOp funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
- LLVM::GlobalOp shmemOp = {};
Operation *moduleOp = funcOp->getParentWithTrait<OpTrait::SymbolTable>();
- shmemOp = getDynamicSharedMemorySymbol(
- rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
+ LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
+ rewriter, loc, moduleOp, addressSpace, alignmentByte, elementType);
// Step 3. Get address of the global symbol
OpBuilder::InsertionGuard guard(rewriter);
@@ -643,15 +654,17 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType,
basePtr, gepArgs);
// Step 5. Create a memref descriptor
- SmallVector<Value> shape, strides;
- Value sizeBytes;
- getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
- sizeBytes);
- auto memRefDescriptor = this->createMemRefDescriptor(
- loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);
-
+ Value result = shmemPtr;
+ if (llvm::isa<MemRefType>(op.getResult().getType())) {
+ SmallVector<Value> shape, strides;
+ Value sizeBytes;
+ getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
+ sizeBytes);
+ result = this->createMemRefDescriptor(loc, memrefType0sz, shmemPtr,
+ shmemPtr, shape, strides, rewriter);
+ }
// Step 5. Replace the op with memref descriptor
- rewriter.replaceOp(op, {memRefDescriptor});
+ rewriter.replaceOp(op, {result});
return success();
}
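For reference, a rough sketch of the IR this lowering emits for the new `!llvm.ptr<3>` result (the symbol name and alignment below are illustrative; the actual values depend on the pattern's `alignmentBit` setting and on any pre-existing globals):

```mlir
// Zero-sized array global that anchors the dynamic shared memory
// (internal linkage, workgroup address space 3).
llvm.mlir.global internal @__dynamic_shmem__0() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>

llvm.func @kernel() {
  // The op is replaced by the address of that global (plus a zero-offset
  // GEP that may fold away); no memref descriptor is built when the
  // result type is already an LLVM pointer.
  %0 = llvm.mlir.addressof @__dynamic_shmem__0 : !llvm.ptr<3>
  llvm.return
}
```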
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 3abaa3b3a81dd..28398c9468d2b 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -2219,19 +2220,27 @@ gpu::SelectObjectAttr::verify(function_ref<InFlightDiagnostic()> emitError,
LogicalResult gpu::DynamicSharedMemoryOp::verify() {
if (!getOperation()->getParentWithTrait<OpTrait::SymbolTable>())
return emitOpError() << "must be inside an op with symbol table";
-
- MemRefType memrefType = getResultMemref().getType();
- // Check address space
- if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
- return emitOpError() << "address space must be "
- << gpu::AddressSpaceAttr::getMnemonic() << "<"
- << stringifyEnum(gpu::AddressSpace::Workgroup) << ">";
+ if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(getResult().getType())) {
+ return success();
}
- if (memrefType.hasStaticShape()) {
- return emitOpError() << "result memref type must be memref<?xi8, "
- "#gpu.address_space<workgroup>>";
+ if (MemRefType memrefType =
+ llvm::dyn_cast<MemRefType>(getResult().getType())) {
+ // Check address space
+ if (!GPUDialect::hasWorkgroupMemoryAddressSpace(memrefType)) {
+ return emitOpError() << "address space must be "
+ << gpu::AddressSpaceAttr::getMnemonic() << "<"
+ << stringifyEnum(gpu::AddressSpace::Workgroup)
+ << ">";
+ }
+ if (memrefType.hasStaticShape() ||
+ !memrefType.getElementType().isInteger(8)) {
+ return emitOpError() << "result memref type must be memref<?xi8, "
+ "#gpu.address_space<workgroup>>";
+ }
+ return success();
}
- return success();
+ return emitOpError() << "result type must be either llvm.ptr or memref<?xi8, "
+ "#gpu.address_space<workgroup>>";
}
//===----------------------------------------------------------------------===//
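To make the new verifier behavior concrete, a small example of the result types it accepts and rejects (mirroring the updated invalid.mlir test below):

```mlir
// Accepted: a dynamically sized i8 memref in workgroup space, or an LLVM pointer.
%ok0 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
%ok1 = gpu.dynamic_shared_memory : !llvm.ptr<3>

// Rejected: a static shape or a non-i8 element type, e.g.
//   %bad = gpu.dynamic_shared_memory : memref<?xf32, #gpu.address_space<workgroup>>
// error: 'gpu.dynamic_shared_memory' op result memref type must be
//        memref<?xi8, #gpu.address_space<workgroup>>
```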
diff --git a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
index d73125fd763e6..75d7fb88e4dd8 100644
--- a/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
+++ b/mlir/test/Dialect/GPU/dynamic-shared-memory.mlir
@@ -99,4 +99,13 @@ gpu.module @modules {
func.return
}
+
+// CHECK-LABEL: llvm.func @func_device_function_plain_pointer
+func.func @func_device_function_plain_pointer() {
+ // CHECK-DAG: %[[S5:.+]] = llvm.mlir.addressof @__dynamic_shmem__3 : !llvm.ptr<3>
+ // CHECK: "test.use.shared.memory"(%[[S5]]) : (!llvm.ptr<3>) -> ()
+ %shmem = gpu.dynamic_shared_memory : !llvm.ptr<3>
+ "test.use.shared.memory"(%shmem) : (!llvm.ptr<3>) -> ()
+ func.return
+}
}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index e9d8f329be8ed..ecebea59e1964 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -818,7 +818,7 @@ func.func @main(%arg0 : index) {
threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1)
dynamic_shared_memory_size %shmemSize
{
- // expected-error @below {{'gpu.dynamic_shared_memory' op result #0 must be 1D memref of 8-bit signless integer values, but got 'memref<?xf32, #gpu.address_space<workgroup>}}
+ // expected-error @below {{'gpu.dynamic_shared_memory' op result memref type must be memref<?xi8, #gpu.address_space<workgroup>>}}
%0 = gpu.dynamic_shared_memory : memref<?xf32, #gpu.address_space<workgroup>>
gpu.terminator
}
From f0cec0c02f232a864ba4dda473cfcd55d81d0e5e Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 12 Jul 2024 18:39:01 +0200
Subject: [PATCH 2/2] Update GPUOpsLowering.cpp
---
mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 08a926fd5caac..05290215b1498 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -630,7 +630,7 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
}
addressSpace = maybeAddressSpace.value();
} else {
- auto ptr = dyn_cast<LLVM::LLVMPointerType>(op.getResult().getType());
+ auto ptr = cast<LLVM::LLVMPointerType>(op.getResult().getType());
addressSpace = ptr.getAddressSpace();
elementType = IntegerType::get(op->getContext(), 8);
alignmentByte = alignmentBit / elementType.getIntOrFloatBitWidth();