[Mlir-commits] [mlir] Fix bug in gpu.memcpy lowering for dynamically shaped operands. (PR #128820)

Arnab Dutta llvmlistbot at llvm.org
Wed Feb 26 20:47:20 PST 2025


https://github.com/arnab-polymage updated https://github.com/llvm/llvm-project/pull/128820

>From 88d79e815a4be9812ad731622264d0276a548098 Mon Sep 17 00:00:00 2001
From: Arnab Dutta <arnab at polymagelabs.com>
Date: Wed, 26 Feb 2025 11:01:16 +0530
Subject: [PATCH] Fix bug in gpu.memcpy lowering for dynamically shaped
 operands.

Compute the number of elements to copy as the product of the sizes of
all dimensions.
---
 .../GPUCommon/GPUToLLVMConversion.cpp         | 18 +++++++++--------
 .../lower-memcpy-to-gpu-runtime-calls.mlir    | 20 +++++++++++++++++++
 2 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8017eb6bb383b..512820bab4097 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -76,14 +76,16 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                        MemRefType type, MemRefDescriptor desc) const {
     Type indexType = ConvertToLLVMPattern::getIndexType();
-    return type.hasStaticShape()
-               ? ConvertToLLVMPattern::createIndexAttrConstant(
-                     rewriter, loc, indexType, type.getNumElements())
-               // For identity maps (verified by caller), the number of
-               // elements is stride[0] * size[0].
-               : rewriter.create<LLVM::MulOp>(loc,
-                                              desc.stride(rewriter, loc, 0),
-                                              desc.size(rewriter, loc, 0));
+    if (type.hasStaticShape())
+      return ConvertToLLVMPattern::createIndexAttrConstant(
+          rewriter, loc, indexType, type.getNumElements());
+    // Compute the number of elements by multiplying all the dim sizes.
+    uint64_t rank = type.getRank();
+    Value numElements = desc.size(rewriter, loc, /*pos=*/0);
+    for (unsigned i = 1; i < rank; i++)
+      numElements = rewriter.create<LLVM::MulOp>(
+          loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
+    return numElements;
   }
 
   MLIRContext *context = &this->getTypeConverter()->getContext();
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
index 3f86b07698279..b45d188a77e3f 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -17,3 +17,23 @@ module attributes {gpu.container_module} {
     return
   }
 }
+
+// -----
+
+module attributes {gpu.container_module} {
+
+  // CHECK: func @dynamic
+  func.func @dynamic(%dst : memref<?x?xf32, 1>, %src : memref<?x?xf32>) {
+    // CHECK: %[[T0:.*]] = llvm.call @mgpuStreamCreate
+    %t0 = gpu.wait async
+    %t1 = gpu.memcpy async [%t0] %dst, %src : memref<?x?xf32, 1>, memref<?x?xf32>
+    // CHECK: %[[DIM_SIZE_0:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    // CHECK-NEXT: %[[DIM_SIZE_1:.*]] = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> 
+    // CHECK: %[[NUM_ELEMENTS:.*]] = llvm.mul %[[DIM_SIZE_0]], %[[DIM_SIZE_1]]  : i64
+    // CHECK: %[[SIZE_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[NUM_ELEMENTS]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+    // CHECK-NEXT: %[[SIZE_INT:.*]] = llvm.ptrtoint %[[SIZE_PTR]] : !llvm.ptr to i64
+    // CHECK: %[[ADDR_CAST:.*]] = llvm.addrspacecast
+    // CHECK: llvm.call @mgpuMemcpy(%[[ADDR_CAST]], %{{.*}}, %[[SIZE_INT]], %[[T0]])
+    return
+  }
+}



More information about the Mlir-commits mailing list