[Mlir-commits] [mlir] Fix bug in gpu.memcpy lowering for dynamically shaped operands. (PR #128820)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 25 21:29:25 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Arnab Dutta (arnab-polymage)

<details>
<summary>Changes</summary>

For dynamically shaped operands, compute the number of elements to be copied by multiplying the sizes of all dimensions, rather than as `stride[0] * size[0]`.

---
Full diff: https://github.com/llvm/llvm-project/pull/128820.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+10-8) 
- (modified) mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir (+20) 


``````````diff
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8017eb6bb383b..512820bab4097 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -76,14 +76,16 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                        MemRefType type, MemRefDescriptor desc) const {
     Type indexType = ConvertToLLVMPattern::getIndexType();
-    return type.hasStaticShape()
-               ? ConvertToLLVMPattern::createIndexAttrConstant(
-                     rewriter, loc, indexType, type.getNumElements())
-               // For identity maps (verified by caller), the number of
-               // elements is stride[0] * size[0].
-               : rewriter.create<LLVM::MulOp>(loc,
-                                              desc.stride(rewriter, loc, 0),
-                                              desc.size(rewriter, loc, 0));
+    if (type.hasStaticShape())
+      return ConvertToLLVMPattern::createIndexAttrConstant(
+          rewriter, loc, indexType, type.getNumElements());
+    // Compute the number of elements by multiplying all the dim sizes.
+    uint64_t rank = type.getRank();
+    Value numElements = desc.size(rewriter, loc, /*pos=*/0);
+    for (unsigned i = 1; i < rank; i++)
+      numElements = rewriter.create<LLVM::MulOp>(
+          loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
+    return numElements;
   }
 
   MLIRContext *context = &this->getTypeConverter()->getContext();
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
index 3f86b07698279..b45d188a77e3f 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -17,3 +17,23 @@ module attributes {gpu.container_module} {
     return
   }
 }
+
+// -----
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @dynamic
+  func.func @dynamic(%dst : memref<?x?xf32, 1>, %src : memref<?x?xf32>) {
+    // CHECK: %[[T0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    %t1 = gpu.memcpy async [%t0] %dst, %src : memref<?x?xf32, 1>, memref<?x?xf32>
+    // CHECK: %[[DIM_SIZE_0:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    // CHECK-NEXT: %[[DIM_SIZE_1:.*]] = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    // CHECK: %[[NUM_ELEMENTS:.*]] = llvm.mul %[[DIM_SIZE_0]], %[[DIM_SIZE_1]]  : i64
+    // CHECK: %[[SIZE_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[NUM_ELEMENTS]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: %[[SIZE_INT:.*]] = llvm.ptrtoint %[[SIZE_PTR]] : !llvm.ptr to i64
+    // CHECK: %[[ADDR_CAST:.*]] = llvm.addrspacecast
+    // CHECK: llvm.call @mgpuMemcpy(%[[ADDR_CAST]], %{{.*}}, %[[SIZE_INT]], %[[T0]])
+    return
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/128820


More information about the Mlir-commits mailing list