[Mlir-commits] [mlir] c13ebb5 - Fix bug in gpu.memcpy lowering for dynamically shaped operands. (#128820)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 3 00:23:12 PST 2025
Author: Arnab Dutta
Date: 2025-03-03T13:53:09+05:30
New Revision: c13ebb527961e96e96ec1913dbbbcc6782512e18
URL: https://github.com/llvm/llvm-project/commit/c13ebb527961e96e96ec1913dbbbcc6782512e18
DIFF: https://github.com/llvm/llvm-project/commit/c13ebb527961e96e96ec1913dbbbcc6782512e18.diff
LOG: Fix bug in gpu.memcpy lowering for dynamically shaped operands. (#128820)
Compute the number of elements to be copied by multiplying dim sizes
along all the dimensions.
Added:
Modified:
mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 8017eb6bb383b..512820bab4097 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -76,14 +76,16 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
MemRefType type, MemRefDescriptor desc) const {
Type indexType = ConvertToLLVMPattern::getIndexType();
- return type.hasStaticShape()
- ? ConvertToLLVMPattern::createIndexAttrConstant(
- rewriter, loc, indexType, type.getNumElements())
- // For identity maps (verified by caller), the number of
- // elements is stride[0] * size[0].
- : rewriter.create<LLVM::MulOp>(loc,
- desc.stride(rewriter, loc, 0),
- desc.size(rewriter, loc, 0));
+ if (type.hasStaticShape())
+ return ConvertToLLVMPattern::createIndexAttrConstant(
+ rewriter, loc, indexType, type.getNumElements());
+ // Compute the number of elements by multiplying all the dim sizes.
+ uint64_t rank = type.getRank();
+ Value numElements = desc.size(rewriter, loc, /*pos=*/0);
+ for (unsigned i = 1; i < rank; i++)
+ numElements = rewriter.create<LLVM::MulOp>(
+ loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
+ return numElements;
}
MLIRContext *context = &this->getTypeConverter()->getContext();
diff --git a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
index 3f86b07698279..b45d188a77e3f 100644
--- a/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir
@@ -17,3 +17,23 @@ module attributes {gpu.container_module} {
return
}
}
+
+// -----
+
+module attributes {gpu.container_module} {
+
+ // CHECK: func @dynamic
+ func.func @dynamic(%dst : memref<?x?xf32, 1>, %src : memref<?x?xf32>) {
+ // CHECK: %[[T0:.*]] = llvm.call @mgpuStreamCreate
+ %t0 = gpu.wait async
+ %t1 = gpu.memcpy async [%t0] %dst, %src : memref<?x?xf32, 1>, memref<?x?xf32>
+ // CHECK: %[[DIM_SIZE_0:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK-NEXT: %[[DIM_SIZE_1:.*]] = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: %[[NUM_ELEMENTS:.*]] = llvm.mul %[[DIM_SIZE_0]], %[[DIM_SIZE_1]] : i64
+ // CHECK: %[[SIZE_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[NUM_ELEMENTS]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+ // CHECK-NEXT: %[[SIZE_INT:.*]] = llvm.ptrtoint %[[SIZE_PTR]] : !llvm.ptr to i64
+ // CHECK: %[[ADDR_CAST:.*]] = llvm.addrspacecast
+ // CHECK: llvm.call @mgpuMemcpy(%[[ADDR_CAST]], %{{.*}}, %[[SIZE_INT]], %[[T0]])
+ return
+ }
+}
More information about the Mlir-commits
mailing list