[Mlir-commits] [mlir] ab95ba7 - [mlir][memref] Implement fast lowering of memref.copy
Stephan Herhut
llvmlistbot at llvm.org
Fri Jan 14 05:22:35 PST 2022
Author: Stephan Herhut
Date: 2022-01-14T14:22:15+01:00
New Revision: ab95ba704da458022f0fb3d7785a1d8b500b41b0
URL: https://github.com/llvm/llvm-project/commit/ab95ba704da458022f0fb3d7785a1d8b500b41b0
DIFF: https://github.com/llvm/llvm-project/commit/ab95ba704da458022f0fb3d7785a1d8b500b41b0.diff
LOG: [mlir][memref] Implement fast lowering of memref.copy
In the absence of layout maps, we can lower memref.copy to a memcpy.
Differential Revision: https://reviews.llvm.org/D116099
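
As an illustrative sketch of the fast path (the function name is hypothetical and the lowered form is paraphrased, not the verbatim pass output): a copy between two identity-layout memrefs now becomes a single memcpy over the contiguous buffer.

// Both operands carry the default (identity) layout, so the new
// pattern emits one llvm.intr.memcpy on the aligned pointers instead
// of calling the generic memrefCopy runtime function.
func @copy_fast(%src: memref<2x3xf32>, %dst: memref<2x3xf32>) {
  memref.copy %src, %dst : memref<2x3xf32> to memref<2x3xf32>
  return
}

The byte count handed to the intrinsic is the product of the dimension sizes times the element size, here 2 * 3 * 4 = 24 bytes.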
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/mlir-cpu-runner/copy.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 3e3adb72cf4a7..0bc6eb923d164 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -706,12 +706,52 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
}
};
+/// Pattern to lower a `memref.copy` to llvm.
+///
+/// For memrefs with identity layouts, the copy is lowered to the llvm
+/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
+/// to the generic `MemrefCopyFn`.
struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
+ lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ auto loc = op.getLoc();
+ auto srcType = op.source().getType().dyn_cast<MemRefType>();
+
+ MemRefDescriptor srcDesc(adaptor.source());
+
+ // Compute number of elements.
+ Value numElements;
+ for (int pos = 0; pos < srcType.getRank(); ++pos) {
+ auto size = srcDesc.size(rewriter, loc, pos);
+ numElements = numElements
+ ? rewriter.create<LLVM::MulOp>(loc, numElements, size)
+ : size;
+ }
+ // Get element size.
+ auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
+ // Compute total.
+ Value totalSize =
+ rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
+
+ Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
+ MemRefDescriptor targetDesc(adaptor.target());
+ Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
+ Value isVolatile = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(rewriter.getI1Type()),
+ rewriter.getBoolAttr(false));
+ rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
+ isVolatile);
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+
+ LogicalResult
+ lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto srcType = op.source().getType().cast<BaseMemRefType>();
auto targetType = op.target().getType().cast<BaseMemRefType>();
@@ -765,6 +805,21 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
return success();
}
+
+ LogicalResult
+ matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcType = op.source().getType().cast<BaseMemRefType>();
+ auto targetType = op.target().getType().cast<BaseMemRefType>();
+
+ if (srcType.hasRank() &&
+ srcType.cast<MemRefType>().getLayout().isIdentity() &&
+ targetType.hasRank() &&
+ targetType.cast<MemRefType>().getLayout().isIdentity())
+ return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
+
+ return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
+ }
};
/// Extracts allocated, aligned pointers and offset from a ranked or unranked
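
For contrast, a hedged sketch of a copy that still takes the function-call path; the strided layout mirrors the one exercised in the test below, and the function name is made up:

// The target layout map is not the identity, so matchAndRewrite
// dispatches to lowerToMemCopyFunctionCall instead of emitting the
// memcpy intrinsic.
func @copy_strided(%src: memref<2x3xf32>,
                   %dst: memref<2x3xf32, offset: 0, strides: [1, 2]>) {
  memref.copy %src, %dst
      : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
  return
}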
diff --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir
index e5a471fe204dc..8581f135bc851 100644
--- a/mlir/test/mlir-cpu-runner/copy.mlir
+++ b/mlir/test/mlir-cpu-runner/copy.mlir
@@ -35,7 +35,7 @@ func @main() -> () {
// CHECK-NEXT: [3, 4, 5]
%copy_two = memref.alloc() : memref<3x2xf32>
- %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
+ %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2]
: memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
%unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
@@ -49,6 +49,13 @@ func @main() -> () {
%copy_empty = memref.alloc() : memref<3x0x1xf32>
// Copying an empty shape should do nothing (and should not crash).
memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
+
+ %input_empty_casted = memref.reinterpret_cast %input_empty to offset: [0], sizes: [0, 3, 1], strides: [3, 1, 1]
+ : memref<3x0x1xf32> to memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]>
+ %copy_empty_casted = memref.alloc() : memref<0x3x1xf32>
+ // Copying a casted empty shape should do nothing (and should not crash).
+ memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]> to memref<0x3x1xf32>
+
memref.dealloc %copy_empty : memref<3x0x1xf32>
memref.dealloc %input_empty : memref<3x0x1xf32>
memref.dealloc %copy_two : memref<3x2xf32>
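
The added test cases also cover empty shapes. On the fast path this is safe by construction: the element-count product includes the zero-sized dimension, so the emitted memcpy copies zero bytes. A minimal sketch (function name hypothetical):

// numElements = 3 * 0 * 1 = 0, so totalSize is 0 and the lowered
// memcpy is a no-op.
func @copy_empty(%src: memref<3x0x1xf32>, %dst: memref<3x0x1xf32>) {
  memref.copy %src, %dst : memref<3x0x1xf32> to memref<3x0x1xf32>
  return
}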