[Mlir-commits] [mlir] c336a06 - [mlir] [memref] Fix alignment bug in memref.copy lowering

Alex Zinenko llvmlistbot at llvm.org
Thu Sep 14 04:18:19 PDT 2023


Author: Felix Schneider
Date: 2023-09-14T13:18:12+02:00
New Revision: c336a06144a7bc93156fa01d5489f8d738cdb590

URL: https://github.com/llvm/llvm-project/commit/c336a06144a7bc93156fa01d5489f8d738cdb590
DIFF: https://github.com/llvm/llvm-project/commit/c336a06144a7bc93156fa01d5489f8d738cdb590.diff

LOG: [mlir] [memref] Fix alignment bug in memref.copy lowering

memref.copy is sometimes lowered to a function call; this function
is passed the element size of the memref in bytes as an argument.
The element size passed to the memrefCopy() function call can be
miscalculated if the LLVM IR uses aligned access to the memory.

This can be fixed by using llvm.getelementptr to calculate the element
size natively. This is also done in the other lowering path that lowers
to an intrinsic.

Fix https://github.com/llvm/llvm-project/issues/64072

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D156126

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
    mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 159fa1da935700e..61bd23f12601c79 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -879,10 +879,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto sourcePtr = promote(unrankedSource);
     auto targetPtr = promote(unrankedTarget);
 
-    unsigned typeSize =
-        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
-    auto elemSize = rewriter.create<LLVM::ConstantOp>(
-        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
+    // Derive size from llvm.getelementptr which will account for any
+    // potential alignment
+    auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
     rewriter.create<LLVM::CallOp>(loc, copyFn,

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 2ece4acc05f5d92..9e44029ad93bd9c 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -558,7 +558,8 @@ func.func @memref_copy_unranked() {
   // CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.struct<(i64, ptr)>, !llvm.ptr
   // CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr)> : (i64) -> !llvm.ptr
   // CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.struct<(i64, ptr)>, !llvm.ptr
-  // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[SIZEPTR:%.*]] = llvm.getelementptr {{%.*}}[1] : (!llvm.ptr) -> !llvm.ptr, i1
+  // CHECK: [[SIZE:%.*]] = llvm.ptrtoint [[SIZEPTR]] : !llvm.ptr to i64
   // CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr, !llvm.ptr) -> ()
   // CHECK: llvm.intr.stackrestore [[STACKSAVE]]
   return

diff  --git a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
index 4b077b8e650dcdd..893c359b7b0718e 100644
--- a/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/typed-pointers.mlir
@@ -82,7 +82,8 @@ func.func @memref_copy_unranked() {
   // CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
   // CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
   // CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
-  // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[SIZEPTR:%.*]] = llvm.getelementptr {{%.*}}[1] : (!llvm.ptr<i1>) -> !llvm.ptr<i1>
+  // CHECK: [[SIZE:%.*]] = llvm.ptrtoint [[SIZEPTR]] : !llvm.ptr<i1> to i64
   // CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr<struct<(i64, ptr<i8>)>>, !llvm.ptr<struct<(i64, ptr<i8>)>>) -> ()
   // CHECK: llvm.intr.stackrestore [[STACKSAVE]]
   return


        


More information about the Mlir-commits mailing list