[Mlir-commits] [mlir] 2219f9f - [mlir][MemRef] Fix MemRefCopyOpLowering to use correct number of bytes
Adrian Kuegel
llvmlistbot at llvm.org
Fri Feb 11 04:59:24 PST 2022
Author: Adrian Kuegel
Date: 2022-02-11T13:59:08+01:00
New Revision: 2219f9f57cff2ecc0402b393630e0975f8873603
URL: https://github.com/llvm/llvm-project/commit/2219f9f57cff2ecc0402b393630e0975f8873603
DIFF: https://github.com/llvm/llvm-project/commit/2219f9f57cff2ecc0402b393630e0975f8873603.diff
LOG: [mlir][MemRef] Fix MemRefCopyOpLowering to use correct number of bytes
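When lowering to a memrefCopy call, the element size for the i1 type was calculated as 0: getTypeSizeInBits() returns 1 for i1, and integer division by 8 truncates that to 0.
Instead of using getTypeSizeInBits() and dividing by 8, we should just use getTypeSize(), which returns the size in bytes (1 for i1).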
Differential Revision: https://reviews.llvm.org/D119540
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 4507b106b1758..a8910c2667a38 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -914,10 +914,10 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto sourcePtr = promote(unrankedSource);
auto targetPtr = promote(unrankedTarget);
- unsigned bitwidth = mlir::DataLayout::closest(op).getTypeSizeInBits(
- srcType.getElementType());
+ unsigned typeSize =
+ mlir::DataLayout::closest(op).getTypeSize(srcType.getElementType());
auto elemSize = rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(bitwidth / 8));
+ loc, getIndexType(), rewriter.getIndexAttr(typeSize));
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
rewriter.create<LLVM::CallOp>(loc, copyFn,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index ee7d36052c4ae..2fc6905117623 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -911,3 +911,62 @@ func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
// CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
llvm.return
}
+
+// -----
+// Ranked identity-layout copy lowers to llvm.intr.memcpy(dst, src, numElements * elementBytes, volatile=false).
+// CHECK-LABEL: func @memref_copy_ranked
+func @memref_copy_ranked() {
+ %0 = memref.alloc() : memref<2xf32>
+ // CHECK: llvm.mlir.constant(2 : index) : i64
+ // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %1 = memref.cast %0 : memref<2xf32> to memref<?xf32>
+ %2 = memref.alloc() : memref<2xf32>
+ // CHECK: llvm.mlir.constant(2 : index) : i64
+ // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ %3 = memref.cast %2 : memref<2xf32> to memref<?xf32>
+ memref.copy %1, %3 : memref<?xf32> to memref<?xf32>
+ // CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: [[DIM0:%.*]] = llvm.extractvalue {{%.*}}[3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[NUM_ELEMS:%.*]] = llvm.mul [[ONE]], [[DIM0]] : i64
+ // CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<f32>
+ // CHECK: [[ONE2:%.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]][[[ONE2]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+ // CHECK: [[ELEM_BYTES:%.*]] = llvm.ptrtoint [[GEP]] : !llvm.ptr<f32> to i64
+ // CHECK: [[NUM_BYTES:%.*]] = llvm.mul [[NUM_ELEMS]], [[ELEM_BYTES]] : i64
+ // CHECK: [[SRC_PTR:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[DST_PTR:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[IS_VOLATILE:%.*]] = llvm.mlir.constant(false) : i1
+ // CHECK: "llvm.intr.memcpy"([[DST_PTR]], [[SRC_PTR]], [[NUM_BYTES]], [[IS_VOLATILE]]) : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()
+ return
+}
+
+// -----
+// Unranked i1 copy: memrefCopy must receive an element size of 1 byte (1 bit / 8 would truncate to 0).
+// CHECK-LABEL: func @memref_copy_unranked
+func @memref_copy_unranked() {
+ %0 = memref.alloc() : memref<2xi1>
+ // CHECK: llvm.mlir.constant(2 : index) : i64
+ // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>
+ %1 = memref.cast %0 : memref<2xi1> to memref<*xi1>
+ %2 = memref.alloc() : memref<2xi1>
+ // CHECK: llvm.mlir.constant(2 : index) : i64
+ // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>
+ %3 = memref.cast %2 : memref<2xi1> to memref<*xi1>
+ memref.copy %1, %3 : memref<*xi1> to memref<*xi1>
+ // CHECK: [[ONE:%.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: [[ALLOCA:%.*]] = llvm.alloca [[ONE]] x !llvm.struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)> : (i64) -> !llvm.ptr<struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>>
+ // CHECK: llvm.store {{%.*}}, [[ALLOCA]] : !llvm.ptr<struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>>
+ // CHECK: [[BITCAST:%.*]] = llvm.bitcast [[ALLOCA]] : !llvm.ptr<struct<(ptr<i1>, ptr<i1>, i64, array<1 x i64>, array<1 x i64>)>> to !llvm.ptr<i8>
+ // CHECK: [[RANK:%.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: [[UNDEF:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+ // CHECK: [[INSERT:%.*]] = llvm.insertvalue [[RANK]], [[UNDEF]][0] : !llvm.struct<(i64, ptr<i8>)>
+ // CHECK: [[INSERT2:%.*]] = llvm.insertvalue [[BITCAST]], [[INSERT]][1] : !llvm.struct<(i64, ptr<i8>)>
+ // CHECK: [[COUNT:%.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: [[ALLOCA2:%.*]] = llvm.alloca [[COUNT]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
+ // CHECK: llvm.store {{%.*}}, [[ALLOCA2]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
+ // CHECK: [[ALLOCA3:%.*]] = llvm.alloca [[COUNT]] x !llvm.struct<(i64, ptr<i8>)> : (i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
+ // CHECK: llvm.store [[INSERT2]], [[ALLOCA3]] : !llvm.ptr<struct<(i64, ptr<i8>)>>
+ // CHECK: [[ELEM_SIZE:%.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: llvm.call @memrefCopy([[ELEM_SIZE]], [[ALLOCA2]], [[ALLOCA3]]) : (i64, !llvm.ptr<struct<(i64, ptr<i8>)>>, !llvm.ptr<struct<(i64, ptr<i8>)>>) -> ()
+ return
+}
More information about the Mlir-commits
mailing list