[Mlir-commits] [mlir] c336a06 - [mlir] [memref] Fix alignment bug in memref.copy lowering
Alex Zinenko
llvmlistbot at llvm.org
Thu Sep 14 04:18:19 PDT 2023
Author: Felix Schneider
Date: 2023-09-14T13:18:12+02:00
New Revision: c336a06144a7bc93156fa01d5489f8d738cdb590
URL: https://github.com/llvm/llvm-project/commit/c336a06144a7bc93156fa01d5489f8d738cdb590
DIFF: https://github.com/llvm/llvm-project/commit/c336a06144a7bc93156fa01d5489f8d738cdb590.diff
LOG: [mlir] [memref] Fix alignment bug in memref.copy lowering
memref.copy is sometimes lowered to a function call; this function
is passed the element size of the memref in bytes as an argument.
The element size passed to the copyMemRef() function call can be
miscalculated if the LLVM IR uses aligned access to the memory.
This is fixed by using llvm.getelementptr to compute the element
size natively, as is already done in the other lowering path, which
lowers to an intrinsic.
Fix https://github.com/llvm/llvm-project/issues/64072
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D156126
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 159fa1da935700e..61bd23f12601c79 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -879,10 +879,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto sourcePtr = promote(unrankedSource);
auto targetPtr = promote(unrankedTarget);
- unsigned typeSize =
- mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
- auto elemSize = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(typeSize));
+ // Derive size from llvm.getelementptr which will account for any
+ // potential alignment
+ auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
rewriter.create<LLVM::CallOp>(loc, copyFn,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 2ece4acc05f5d92..9e44029ad93bd9c 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -558,7 +558,8 @@ func.func @memref_copy_unranked() {
// CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.struct<(i64, ptr)>, !llvm.ptr
// CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr)> : (i64) -> !llvm.ptr
// CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.struct<(i64, ptr)>, !llvm.ptr
- // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: [[SIZEPTR:%.*]] = llvm.getelementptr {{%.*}}[1] : (!llvm.ptr) -> !llvm.ptr, i1
+ // CHECK: [[SIZE:%.*]] = llvm.ptrtoint [[SIZEPTR]] : !llvm.ptr to i64
// CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr, !llvm.ptr) -> ()
// CHECK: llvm.intr.stackrestore [[STACKSAVE]]
return
diff --git a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
index 4b077b8e650dcdd..893c359b7b0718e 100644
--- a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
@@ -82,7 +82,8 @@ func.func @memref_copy_unranked() {
// CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
// CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
- // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: [[SIZEPTR:%.*]] = llvm.getelementptr {{%.*}}[1] : (!llvm.ptr<i1>) -> !llvm.ptr<i1>
+ // CHECK: [[SIZE:%.*]] = llvm.ptrtoint [[SIZEPTR]] : !llvm.ptr<i1> to i64
// CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr<struct<(i64, ptr<i8>)>>, !llvm.ptr<struct<(i64, ptr<i8>)>>) -> ()
// CHECK: llvm.intr.stackrestore [[STACKSAVE]]
return
More information about the Mlir-commits
mailing list