[Mlir-commits] [mlir] c13ebb5 - Fix bug in gpu.memcpy lowering for dynamically shaped operands. (#128820)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 3 00:23:12 PST 2025


Author: Arnab Dutta
Date: 2025-03-03T13:53:09+05:30
New Revision: c13ebb527961e96e96ec1913dbbbcc6782512e18

URL: https://github.com/llvm/llvm-project/commit/c13ebb527961e96e96ec1913dbbbcc6782512e18
DIFF: https://github.com/llvm/llvm-project/commit/c13ebb527961e96e96ec1913dbbbcc6782512e18.diff

LOG: Fix bug in gpu.memcpy lowering for dynamically shaped operands. (#128820)

Compute the number of elements to be copied by multiplying the sizes of all
dimensions read from the memref descriptor, rather than relying on
stride[0] * size[0].
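
For illustration only, a minimal standalone C++ sketch (not the MLIR rewriter
API used in the patch below) of the element-count computation this change
introduces; the helper name and the plain vector of runtime sizes are
assumptions made for the example:

    #include <cstdint>
    #include <vector>

    // Multiply the runtime sizes of every dimension to get the total
    // element count, mirroring the size[0] * size[1] * ... * size[rank-1]
    // product the lowering builds with LLVM::MulOp.
    static uint64_t numElements(const std::vector<uint64_t> &sizes) {
      uint64_t count = 1;
      for (uint64_t s : sizes)
        count *= s;
      return count;
    }

For a memref<?x?xf32> with runtime sizes {M, N} this yields M * N, which the
lowering then scales to a byte size via getelementptr/ptrtoint.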

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8017eb6bb383b..512820bab4097 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -76,14 +76,16 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                        MemRefType type, MemRefDescriptor desc) const {
     Type indexType = ConvertToLLVMPattern::getIndexType();
-    return type.hasStaticShape()
-               ? ConvertToLLVMPattern::createIndexAttrConstant(
-                     rewriter, loc, indexType, type.getNumElements())
-               // For identity maps (verified by caller), the number of
-               // elements is stride[0] * size[0].
-               : rewriter.create<LLVM::MulOp>(loc,
-                                              desc.stride(rewriter, loc, 0),
-                                              desc.size(rewriter, loc, 0));
+    if (type.hasStaticShape())
+      return ConvertToLLVMPattern::createIndexAttrConstant(
+          rewriter, loc, indexType, type.getNumElements());
+    // Compute the number of elements by multiplying all the dim sizes.
+    uint64_t rank = type.getRank();
+    Value numElements = desc.size(rewriter, loc, /*pos=*/0);
+    for (unsigned i = 1; i < rank; i++)
+      numElements = rewriter.create<LLVM::MulOp>(
+          loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
+    return numElements;
   }
 
   MLIRContext *context = &this->getTypeConverter()->getContext();

diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
index 3f86b07698279..b45d188a77e3f 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -17,3 +17,23 @@ module attributes {gpu.container_module} {
     return
   }
 }
+
+// -----
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @dynamic
+  func.func @dynamic(%dst : memref<?x?xf32, 1>, %src : memref<?x?xf32>) {
+    // CHECK: %[[T0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    %t1 = gpu.memcpy async [%t0] %dst, %src : memref<?x?xf32, 1>, memref<?x?xf32>
+    // CHECK: %[[DIM_SIZE_0:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    // CHECK-NEXT: %[[DIM_SIZE_1:.*]] = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    // CHECK: %[[NUM_ELEMENTS:.*]] = llvm.mul %[[DIM_SIZE_0]], %[[DIM_SIZE_1]]  : i64
+    // CHECK: %[[SIZE_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[NUM_ELEMENTS]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: %[[SIZE_INT:.*]] = llvm.ptrtoint %[[SIZE_PTR]] : !llvm.ptr to i64
+    // CHECK: %[[ADDR_CAST:.*]] = llvm.addrspacecast
+    // CHECK: llvm.call @mgpuMemcpy(%[[ADDR_CAST]], %{{.*}}, %[[SIZE_INT]], %[[T0]])
+    return
+  }
+}

More information about the Mlir-commits mailing list