[Mlir-commits] [mlir] [mlir][GPU] Lower gpu.memcpy with an offset to memcpy (PR #115687)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 10 20:49:25 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: None (NaOHCC)

<details>
<summary>Changes</summary>

This commit extends the lowering of `gpu.memcpy` to handle memref types that have an offset but still a static, contiguous (row-major) memory layout, allowing the following IR to be lowered:

```mlir
%view = memref.subview %mem[1,0] [1,2] [1,1] : memref<4x4xf32> to memref<2xf32, strided<[1], offset: 4>>
gpu.memcpy %d, %view : memref<2xf32>, memref<2xf32, strided<[1], offset: 4>>
```

Related discussion: [https://discourse.llvm.org/t/gpu-memcpy-does-not-support-generic-memref-layouts/80695](https://discourse.llvm.org/t/gpu-memcpy-does-not-support-generic-memref-layouts/80695)

---
Full diff: https://github.com/llvm/llvm-project/pull/115687.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h (+5) 
- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+34-18) 
- (modified) mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp (+1-1) 
- (modified) mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp (+4) 
- (modified) mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir (+18-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index a761a77a407e87..8f3d8a73dd6154 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -31,6 +31,11 @@ namespace memref {
 /// contiguous chunk of memory.
 bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 
+/// Return true, if the memref type has a rank and contains at least
+/// one dimension of size 0, indicating it is empty. UnrankedMemRefType is
+/// considered non-empty by this function.
+bool isEmpty(BaseMemRefType type);
+
 /// For a `memref` with `offset`, `sizes` and `strides`, returns the
 /// offset, size, and potentially the size padded at the front to use for the
 /// linearized `memref`.
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 92b28ff9c58737..ba05b4efb751b8 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -997,34 +998,49 @@ static Value bitAndAddrspaceCast(Location loc,
 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  auto memRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
+  auto srcMemRefType = cast<MemRefType>(memcpyOp.getSrc().getType());
+  auto dstMemRefType = cast<MemRefType>(memcpyOp.getDst().getType());
 
   if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
-      !isConvertibleAndHasIdentityMaps(memRefType) ||
       failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
     return failure();
 
+  auto isContiguousMemrefType = [&](MemRefType type) {
+    // We can use memcpy for memrefs if they have an identity layout or are
+    // contiguous with an arbitrary offset.
+    return !memref::isEmpty(type) &&
+           memref::isStaticShapeAndContiguousRowMajor(type);
+  };
+
+  if (!(isContiguousMemrefType(srcMemRefType) &&
+        isContiguousMemrefType(dstMemRefType)))
+    return rewriter.notifyMatchFailure(
+        memcpyOp, "Expected both operands to be non-empty memrefs with a "
+                  "static, contiguous row-major shape.");
+
   auto loc = memcpyOp.getLoc();
 
-  MemRefDescriptor srcDesc(adaptor.getSrc());
-  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
+  auto getBufferPtr = [&](Value convertedOperand,
+                          MemRefType originalOperandType) {
+    MemRefDescriptor desc(convertedOperand);
+    return desc.bufferPtr(rewriter, loc, *getTypeConverter(),
+                          originalOperandType);
+  };
+
+  auto srcBufferPtr = getBufferPtr(adaptor.getSrc(), srcMemRefType);
+  auto dstBufferPtr = getBufferPtr(adaptor.getDst(), dstMemRefType);
 
-  Type elementPtrType = getElementPtrType(memRefType);
-  Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType);
-  Value gepPtr = rewriter.create<LLVM::GEPOp>(
-      loc, elementPtrType,
-      typeConverter->convertType(memRefType.getElementType()), nullPtr,
-      numElements);
-  auto sizeBytes =
-      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
+  Value numElements = ConvertToLLVMPattern::getNumElements(
+      loc, srcMemRefType, /* dynamicSizes */ {}, rewriter);
+  // Get element size.
+  Value sizeInBytes =
+      getSizeInBytes(loc, srcMemRefType.getElementType(), rewriter);
+  Value sizeBytes = rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
 
-  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
-                                 srcDesc.alignedPtr(rewriter, loc),
+  auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, srcBufferPtr,
+                                 *getTypeConverter());
+  auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, dstBufferPtr,
                                  *getTypeConverter());
-  auto dst = bitAndAddrspaceCast(
-      loc, rewriter, llvmPointerType,
-      MemRefDescriptor(adaptor.getDst()).alignedPtr(rewriter, loc),
-      *getTypeConverter());
 
   auto stream = adaptor.getAsyncDependencies().front();
   memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4bfa536cc8a44a..5446cc7bf36386 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -869,7 +869,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
       // special case handled by memrefCopy.
       return memrefType &&
              (memrefType.getLayout().isIdentity() ||
-              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
+              (!memref::isEmpty(memrefType) &&
                memref::isStaticShapeAndContiguousRowMajor(memrefType)));
     };
 
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 6de744a7f75244..754de018d076b9 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -49,6 +49,10 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
   return curDim < 0;
 }
 
+bool isEmpty(BaseMemRefType type) {
+  return type.hasRank() && llvm::is_contained(type.getShape(), 0);
+}
+
 std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     OpBuilder &builder, Location loc, int srcBits, int dstBits,
     OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
index 3f86b076982795..c427a3f468c17c 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -6,7 +6,7 @@ module attributes {gpu.container_module} {
   func.func @foo(%dst : memref<7xf32, 1>, %src : memref<7xf32>) {
     // CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
     %t0 = gpu.wait async
-    // CHECK: %[[size_bytes:.*]] = llvm.ptrtoint
+    // CHECK: %[[size_bytes:.*]] = llvm.mul
     // CHECK-NOT: llvm.addrspacecast
     // CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
     // CHECK: llvm.call @mgpuMemcpy(%[[addr_cast]], %{{.*}}, %[[size_bytes]], %[[t0]])
@@ -16,4 +16,21 @@ module attributes {gpu.container_module} {
     gpu.wait [%t1]
     return
   }
+
+  // CHECK: func @test_copy_memref_with_offset
+  func.func @test_copy_memref_with_offset(%dst : memref<10xf32, strided<[1], offset: 8>>, %src : memref<10xf32, strided<[1], offset: 3>>) {
+    // CHECK: %[[stream:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    // CHECK: %[[cst3:.*]] = llvm.mlir.constant(3 : index)
+    // CHECK: %[[src:.*]] = llvm.getelementptr %{{.*}}[%[[cst3]]]
+    // CHECK: %[[cst8:.*]] = llvm.mlir.constant(8 : index) 
+    // CHECK: %[[dst:.*]] = llvm.getelementptr %{{.*}}[%[[cst8]]]
+    // CHECK: %[[size_bytes:.*]] = llvm.mul
+    // CHECK: llvm.call @mgpuMemcpy(%[[dst]], %[[src]], %[[size_bytes]], %[[stream]])
+    %t1 = gpu.memcpy async [%t0] %dst, %src : memref<10xf32, strided<[1], offset: 8>>, memref<10xf32, strided<[1], offset: 3>>
+    // CHECK: llvm.call @mgpuStreamSynchronize(%[[stream]])
+    // CHECK: llvm.call @mgpuStreamDestroy(%[[stream]])
+    gpu.wait [%t1]
+    return
+  }
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/115687


More information about the Mlir-commits mailing list