[Mlir-commits] [mlir] [mlir][GPU] Lower gpu.memcpy with an offset to memcpy (PR #115687)
llvmlistbot at llvm.org
Sun Nov 10 20:49:25 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-memref
Author: None (NaOHCC)
This commit extends the `gpu.memcpy` lowering to handle memref types that have an offset but a contiguous memory layout, allowing IR such as the following to be lowered:
```mlir
%view = memref.subview %mem[1,0] [1,2] [1,1] : memref<4x4xf32> to memref<2xf32, strided<[1], offset: 4>>
gpu.memcpy %d, %view : memref<2xf32>, memref<2xf32, strided<[1], offset: 4>>
```
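With this change, the example lowers to a plain runtime memcpy: the static offset is folded into the source pointer via a GEP, and the byte count is computed as element count times element size. A rough sketch follows, with placeholder SSA names (`%aligned_src`, `%dst_ptr`, `%stream`) that are illustrative rather than actual pass output; in the real output the element size comes from a null-GEP/ptrtoint idiom rather than a literal constant:
```mlir
// Offset folded into the source pointer (offset 4 from the subview above).
%c4 = llvm.mlir.constant(4 : index) : i64
%src_ptr = llvm.getelementptr %aligned_src[%c4] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// Byte count = number of elements * element size.
%c2 = llvm.mlir.constant(2 : index) : i64
%f32_size = llvm.mlir.constant(4 : index) : i64
%size = llvm.mul %c2, %f32_size : i64
llvm.call @mgpuMemcpy(%dst_ptr, %src_ptr, %size, %stream) : (!llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> ()
```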
Related discussion: https://discourse.llvm.org/t/gpu-memcpy-does-not-support-generic-memref-layouts/80695
---
Full diff: https://github.com/llvm/llvm-project/pull/115687.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h (+5)
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+34-18)
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+1-1)
- (modified) mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp (+4)
- (modified) mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir (+18-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index a761a77a407e87..8f3d8a73dd6154 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -31,6 +31,11 @@ namespace memref {
/// contiguous chunk of memory.
bool isStaticShapeAndContiguousRowMajor(MemRefType type);
+/// Returns true if the memref type has a rank and contains at least one
+/// dimension of size 0, i.e., it is empty. An UnrankedMemRefType is
+/// considered non-empty by this function.
+bool isEmpty(BaseMemRefType type);
+
/// For a `memref` with `offset`, `sizes` and `strides`, returns the
/// offset, size, and potentially the size padded at the front to use for the
/// linearized `memref`.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 92b28ff9c58737..ba05b4efb751b8 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -31,6 +31,7 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -997,34 +998,49 @@ static Value bitAndAddrspaceCast(Location loc,
LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
- auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
+ auto srcMemRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
+ auto dstMemRefType = cast<MemRefType>(memcpyOp.getDst().getType());
if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
- !isConvertibleAndHasIdentityMaps(memRefType) ||
failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
return failure();
+ auto isContiguousMemrefType = [&](MemRefType type) {
+ // We can use memcpy for memrefs if they have an identity layout or are
+ // contiguous with an arbitrary offset.
+ return !memref::isEmpty(type) &&
+ memref::isStaticShapeAndContiguousRowMajor(type);
+ };
+
+ if (!(isContiguousMemrefType(srcMemRefType) &&
+ isContiguousMemrefType(dstMemRefType)))
+    return rewriter.notifyMatchFailure(
+        memcpyOp, "expected both operands to be non-empty memrefs with "
+                  "static, contiguous row-major shapes");
+
auto loc = memcpyOp.getLoc();
- MemRefDescriptor srcDesc(adaptor.getSrc());
- Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
+ auto getBufferPtr = [&](Value convertedOperand,
+ MemRefType originalOperandType) {
+ MemRefDescriptor desc(convertedOperand);
+ return desc.bufferPtr(rewriter, loc, *getTypeConverter(),
+ originalOperandType);
+ };
+
+ auto srcBufferPtr = getBufferPtr(adaptor.getSrc(), srcMemRefType);
+ auto dstBufferPtr = getBufferPtr(adaptor.getDst(), dstMemRefType);
- Type elementPtrType = getElementPtrType(memRefType);
- Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
- Value gepPtr = rewriter.create<LLVM::GEPOp>(
- loc, elementPtrType,
- typeConverter->convertType(memRefType.getElementType()), nullPtr,
- numElements);
- auto sizeBytes =
- rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+ Value numElements = ConvertToLLVMPattern::getNumElements(
+ loc, srcMemRefType, /* dynamicSizes */ {}, rewriter);
+ // Get element size.
+ Value sizeInBytes =
+ getSizeInBytes(loc, srcMemRefType.getElementType(), rewriter);
+ Value sizeBytes = rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
- auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
- srcDesc.alignedPtr(rewriter, loc),
+ auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, srcBufferPtr,
+ *getTypeConverter());
+ auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, dstBufferPtr,
*getTypeConverter());
- auto dst = bitAndAddrspaceCast(
- loc, rewriter, llvmPointerType,
- MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
- *getTypeConverter());
auto stream = adaptor.getAsyncDependencies().front();
memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..5446cc7bf36386 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -869,7 +869,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
// special case handled by memrefCopy.
return memrefType &&
(memrefType.getLayout().isIdentity() ||
- (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
+ (!memref::isEmpty(memrefType) &&
memref::isStaticShapeAndContiguousRowMajor(memrefType)));
};
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 6de744a7f75244..754de018d076b9 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -49,6 +49,10 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
return curDim < 0;
}
+bool isEmpty(BaseMemRefType type) {
+ return type.hasRank() && llvm::is_contained(type.getShape(), 0);
+}
+
std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
OpBuilder &builder, Location loc, int srcBits, int dstBits,
OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
index 3f86b076982795..c427a3f468c17c 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -6,7 +6,7 @@ module attributes {gpu.container_module} {
func.func @foo(%dst : memref<7xf32, 1>, %src : memref<7xf32>) {
// CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
%t0 = gpu.wait async
- // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+ // CHECK: %[[size_bytes:.*]] = llvm.mul
// CHECK-NOT: llvm.addrspacecast
// CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
// CHECK: llvm.call @mgpuMemcpy(%[[addr_cast]], %{{.*}}, %[[size_bytes]], %[[t0]])
@@ -16,4 +16,21 @@ module attributes {gpu.container_module} {
gpu.wait [%t1]
return
}
+
+ // CHECK: func @test_copy_memref_with_offset
+ func.func @test_copy_memref_with_offset(%dst : memref<10xf32, strided<[1], offset: 8>>, %src : memref<10xf32, strided<[1], offset: 3>>) {
+ // CHECK: %[[stream:.*]] = llvm.call @mgpuStreamCreate
+ %t0 = gpu.wait async
+ // CHECK: %[[cst3:.*]] = llvm.mlir.constant(3 : index)
+ // CHECK: %[[src:.*]] = llvm.getelementptr %{{.*}}[%[[cst3]]]
+ // CHECK: %[[cst8:.*]] = llvm.mlir.constant(8 : index)
+ // CHECK: %[[dst:.*]] = llvm.getelementptr %{{.*}}[%[[cst8]]]
+ // CHECK: %[[size_bytes:.*]] = llvm.mul
+ // CHECK: llvm.call @mgpuMemcpy(%[[dst]], %[[src]], %[[size_bytes]], %[[stream]])
+ %t1 = gpu.memcpy async [%t0] %dst, %src : memref<10xf32, strided<[1], offset: 8>>, memref<10xf32, strided<[1], offset: 3>>
+ // CHECK: llvm.call @mgpuStreamSynchronize(%[[stream]])
+ // CHECK: llvm.call @mgpuStreamDestroy(%[[stream]])
+ gpu.wait [%t1]
+ return
+ }
}
``````````
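For contrast, here is a sketch (all values hypothetical) of inputs the pattern still rejects with a match failure rather than lowering to a runtime memcpy:
```mlir
// Non-contiguous: the stride of 4 leaves gaps between elements, so
// isStaticShapeAndContiguousRowMajor returns false.
%col = memref.subview %mem[0, 0] [4, 1] [1, 1]
    : memref<4x4xf32> to memref<4xf32, strided<[4]>>
gpu.memcpy %d2, %col : memref<4xf32>, memref<4xf32, strided<[4]>>

// Empty: a zero-sized dimension makes memref::isEmpty return true.
gpu.memcpy %d3, %empty : memref<0x4xf32>, memref<0x4xf32>
```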
https://github.com/llvm/llvm-project/pull/115687