[Mlir-commits] [mlir] 2219f9f - [mlir][MemRef] Fix MemRefCopyOpLowering to use correct number of bytes

Adrian Kuegel llvmlistbot at llvm.org
Fri Feb 11 04:59:24 PST 2022


Author: Adrian Kuegel
Date: 2022-02-11T13:59:08+01:00
New Revision: 2219f9f57cff2ecc0402b393630e0975f8873603

URL: https://github.com/llvm/llvm-project/commit/2219f9f57cff2ecc0402b393630e0975f8873603
DIFF: https://github.com/llvm/llvm-project/commit/2219f9f57cff2ecc0402b393630e0975f8873603.diff

LOG: [mlir][MemRef] Fix MemRefCopyOpLowering to use correct number of bytes

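When lowering to a memrefCopy call, the element size for the i1 type was calculated as 0 bytes:
getTypeSizeInBits() returns 1 for i1, and the integer division 1 / 8 truncates to 0.
Instead of using getTypeSizeInBits() and dividing by 8, we should just use getTypeSize(),
which reports the size in bytes (1 for i1).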

Differential Revision: https://reviews.llvm.org/D119540

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4507b106b1758..a8910c2667a38 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -914,10 +914,10 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto sourcePtr = promote(unrankedSource);
     auto targetPtr = promote(unrankedTarget);
 
-    unsigned bitwidth = mlir::DataLayout::closest(op).getTypeSizeInBits(
-        srcType.getElementType());
+    unsigned typeSize =
+        mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
     auto elemSize = rewriter.create<LLVM::ConstantOp>(
-        loc, getIndexType(), rewriter.getIndexAttr(bitwidth / 8));
+        loc, getIndexType(), rewriter.getIndexAttr(typeSize));
     auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
         op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
     rewriter.create<LLVM::CallOp>(loc, copyFn,

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index ee7d36052c4ae..2fc6905117623 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -911,3 +911,62 @@ func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
   // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_copy_ranked
+func @memref_copy_ranked() {
+  %0 = memref.alloc() : memref<2xf32>
+  // CHECK: llvm.mlir.constant(2 : index) : i64
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  %1 = memref.cast %0 : memref<2xf32> to memref<?xf32>
+  %2 = memref.alloc() : memref<2xf32>
+  // CHECK: llvm.mlir.constant(2 : index) : i64
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  %3 = memref.cast %2 : memref<2xf32> to memref<?xf32>
+  memref.copy %1, %3 : memref<?xf32> to memref<?xf32>
+  // CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[EXTRACT0:%.*]] = llvm.extractvalue {{%.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[MUL:%.*]] = llvm.mul [[ONE]], [[EXTRACT0]] : i64
+  // CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<f32>
+  // CHECK: [[ONE2:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]][[[ONE2]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+  // CHECK: [[PTRTOINT:%.*]] = llvm.ptrtoint [[GEP]] : !llvm.ptr<f32> to i64
+  // CHECK: [[SIZE:%.*]] = llvm.mul [[MUL]], [[PTRTOINT]] : i64
+  // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[EXTRACT2:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[VOLATILE:%.*]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.memcpy"([[EXTRACT2]], [[EXTRACT1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_copy_unranked
+func @memref_copy_unranked() {
+  %0 = memref.alloc() : memref<2xi1>
+  // CHECK: llvm.mlir.constant(2 : index) : i64
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>
+  %1 = memref.cast %0 : memref<2xi1> to memref<*xi1>
+  %2 = memref.alloc() : memref<2xi1>
+  // CHECK: llvm.mlir.constant(2 : index) : i64
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>
+  %3 = memref.cast %2 : memref<2xi1> to memref<*xi1>
+  memref.copy %1, %3 : memref<*xi1> to memref<*xi1>
+  // CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[ALLOCA:%.*]] = llvm.alloca %35 x !llvm.struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>>
+  // CHECK: llvm.store {{%.*}}, [[ALLOCA]] : !llvm.ptr<struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>>
+  // CHECK: [[BITCAST:%.*]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>> to !llvm.ptr<i8>
+  // CHECK: [[RANK:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[UNDEF:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+  // CHECK: [[INSERT:%.*]] = llvm.insertvalue [[RANK]], [[UNDEF]][0] : !llvm.struct<(i64, ptr<i8>)>
+  // CHECK: [[INSERT2:%.*]] = llvm.insertvalue [[BITCAST]], [[INSERT]][1] : !llvm.struct<(i64, ptr<i8>)>
+  // CHECK: [[RANK2:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[ALLOCA2:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
+  // CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
+  // CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[RANK2]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
+  // CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
+  // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: llvm.call @memrefCopy([[SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr<struct<(i64, ptr<i8>)>>, !llvm.ptr<struct<(i64, ptr<i8>)>>) -> ()
+  return
+}


        


More information about the Mlir-commits mailing list