[Mlir-commits] [mlir] 27cd2a6 - [mlir][MemRef] Lower memref.copy with an offset to memcpy
Benjamin Kramer
llvmlistbot at llvm.org
Wed Feb 16 08:19:05 PST 2022
Author: Benjamin Kramer
Date: 2022-02-16T17:18:31+01:00
New Revision: 27cd2a6284b8c59f5dbd9086cf80db3b7b7047b1
URL: https://github.com/llvm/llvm-project/commit/27cd2a6284b8c59f5dbd9086cf80db3b7b7047b1
DIFF: https://github.com/llvm/llvm-project/commit/27cd2a6284b8c59f5dbd9086cf80db3b7b7047b1.diff
LOG: [mlir][MemRef] Lower memref.copy with an offset to memcpy
memcpy can handle memref copies whose source or target has a nonzero offset, as long as both are contiguous.
Differential Revision: https://reviews.llvm.org/D119938
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index a8910c2667a38..56413c4155903 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -857,12 +857,18 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
+ Value srcOffset = srcDesc.offset(rewriter, loc);
+ Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
+ srcBasePtr, srcOffset);
MemRefDescriptor targetDesc(adaptor.target());
Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
+ Value targetOffset = targetDesc.offset(rewriter, loc);
+ Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
+ targetBasePtr, targetOffset);
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getI1Type()),
rewriter.getBoolAttr(false));
- rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
+ rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
isVolatile);
rewriter.eraseOp(op);
@@ -933,10 +939,18 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto srcType = op.source().getType().cast<BaseMemRefType>();
auto targetType = op.target().getType().cast<BaseMemRefType>();
- if (srcType.hasRank() &&
- srcType.cast<MemRefType>().getLayout().isIdentity() &&
- targetType.hasRank() &&
- targetType.cast<MemRefType>().getLayout().isIdentity())
+ auto isContiguousMemrefType = [](BaseMemRefType type) {
+ auto memrefType = type.dyn_cast<mlir::MemRefType>();
+ // We can use memcpy for memrefs if they have an identity layout or are
+ // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
+ // special case handled by memrefCopy.
+ return memrefType &&
+ (memrefType.getLayout().isIdentity() ||
+ (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
+ isStaticShapeAndContiguousRowMajor(memrefType)));
+ };
+
+ if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 2fc6905117623..85c6ca17747ba 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -933,10 +933,55 @@ func @memref_copy_ranked() {
// CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]][[[ONE2]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: [[PTRTOINT:%.*]] = llvm.ptrtoint [[GEP]] : !llvm.ptr<f32> to i64
// CHECK: [[SIZE:%.*]] = llvm.mul [[MUL]], [[PTRTOINT]] : i64
- // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: [[EXTRACT2:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[EXTRACT1P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[EXTRACT1O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[GEP1:%.*]] = llvm.getelementptr [[EXTRACT1P]][[[EXTRACT1O]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+ // CHECK: [[EXTRACT2P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[EXTRACT2O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: [[GEP2:%.*]] = llvm.getelementptr [[EXTRACT2P]][[[EXTRACT2O]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
// CHECK: [[VOLATILE:%.*]] = llvm.mlir.constant(false) : i1
- // CHECK: "llvm.intr.memcpy"([[EXTRACT2]], [[EXTRACT1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()
+ // CHECK: "llvm.intr.memcpy"([[GEP2]], [[GEP1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()
+ return
+}
+
+
+// -----
+
+// CHECK-LABEL: func @memref_copy_contiguous
+#map = affine_map<(d0, d1)[s0] -> (d0 * 2 + s0 + d1)>
+func @memref_copy_contiguous(%in: memref<16x2xi32>, %offset: index) {
+ %buf = memref.alloc() : memref<1x2xi32>
+ %sub = memref.subview %in[%offset, 0] [1, 2] [1, 1] : memref<16x2xi32> to memref<1x2xi32, #map>
+ memref.copy %sub, %buf : memref<1x2xi32, #map> to memref<1x2xi32>
+ // CHECK: [[EXTRACT0:%.*]] = llvm.extractvalue {{%.*}}[3, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: [[MUL1:%.*]] = llvm.mul {{.*}}, [[EXTRACT0]] : i64
+ // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue {{%.*}}[3, 1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: [[MUL2:%.*]] = llvm.mul [[MUL1]], [[EXTRACT1]] : i64
+ // CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<i32>
+ // CHECK: [[ONE2:%.*]] = llvm.mlir.constant(1 : index) : i64
+ // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]][[[ONE2]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+ // CHECK: [[PTRTOINT:%.*]] = llvm.ptrtoint [[GEP]] : !llvm.ptr<i32> to i64
+ // CHECK: [[SIZE:%.*]] = llvm.mul [[MUL2]], [[PTRTOINT]] : i64
+ // CHECK: [[EXTRACT1P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: [[EXTRACT1O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: [[GEP1:%.*]] = llvm.getelementptr [[EXTRACT1P]][[[EXTRACT1O]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+ // CHECK: [[EXTRACT2P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: [[EXTRACT2O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+ // CHECK: [[GEP2:%.*]] = llvm.getelementptr [[EXTRACT2P]][[[EXTRACT2O]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+ // CHECK: [[VOLATILE:%.*]] = llvm.mlir.constant(false) : i1
+ // CHECK: "llvm.intr.memcpy"([[GEP2]], [[GEP1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i32>, !llvm.ptr<i32>, i64, i1) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_copy_noncontiguous
+#map = affine_map<(d0, d1)[s0] -> (d0 * 2 + s0 + d1)>
+func @memref_copy_noncontiguous(%in: memref<16x2xi32>, %offset: index) {
+ %buf = memref.alloc() : memref<2x1xi32>
+ %sub = memref.subview %in[%offset, 0] [2, 1] [1, 1] : memref<16x2xi32> to memref<2x1xi32, #map>
+ memref.copy %sub, %buf : memref<2x1xi32, #map> to memref<2x1xi32>
+ // CHECK: llvm.call @memrefCopy
return
}
More information about the Mlir-commits mailing list