[Mlir-commits] [mlir] ab95ba7 - [mlir][memref] Implement fast lowering of memref.copy

Stephan Herhut llvmlistbot at llvm.org
Fri Jan 14 05:22:35 PST 2022


Author: Stephan Herhut
Date: 2022-01-14T14:22:15+01:00
New Revision: ab95ba704da458022f0fb3d7785a1d8b500b41b0

URL: https://github.com/llvm/llvm-project/commit/ab95ba704da458022f0fb3d7785a1d8b500b41b0
DIFF: https://github.com/llvm/llvm-project/commit/ab95ba704da458022f0fb3d7785a1d8b500b41b0.diff

LOG: [mlir][memref] Implement fast lowering of memref.copy

In the absence of layout maps (i.e. when both the source and the target are ranked memrefs with identity layouts), we can lower memref.copy to a memcpy.

Differential Revision: https://reviews.llvm.org/D116099

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/mlir-cpu-runner/copy.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 3e3adb72cf4a7..0bc6eb923d164 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -706,12 +706,52 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
   }
 };
 
+/// Pattern to lower a `memref.copy` to llvm.
+///
+/// For memrefs with identity layouts, the copy is lowered to the llvm
+/// `memcpy` intrinsic. For non-identity layouts, the copy is lowered to a call
+/// to the generic `MemrefCopyFn`.
 struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
   using ConvertOpToLLVMPattern<memref::CopyOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
+  lowerToMemCopyIntrinsic(memref::CopyOp op, OpAdaptor adaptor,
+                          ConversionPatternRewriter &rewriter) const {
+    auto loc = op.getLoc();
+    auto srcType = op.source().getType().dyn_cast<MemRefType>();
+
+    MemRefDescriptor srcDesc(adaptor.source());
+
+    // Compute number of elements.
+    Value numElements;
+    for (int pos = 0; pos < srcType.getRank(); ++pos) {
+      auto size = srcDesc.size(rewriter, loc, pos);
+      numElements = numElements
+                        ? rewriter.create<LLVM::MulOp>(loc, numElements, size)
+                        : size;
+    }
+    // Get element size.
+    auto sizeInBytes = getSizeInBytes(loc, srcType.getElementType(), rewriter);
+    // Compute total.
+    Value totalSize =
+        rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
+
+    Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
+    MemRefDescriptor targetDesc(adaptor.target());
+    Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
+    Value isVolatile = rewriter.create<LLVM::ConstantOp>(
+        loc, typeConverter->convertType(rewriter.getI1Type()),
+        rewriter.getBoolAttr(false));
+    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
+                                    isVolatile);
+    rewriter.eraseOp(op);
+
+    return success();
+  }
+
+  LogicalResult
+  lowerToMemCopyFunctionCall(memref::CopyOp op, OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
     auto srcType = op.source().getType().cast<BaseMemRefType>();
     auto targetType = op.target().getType().cast<BaseMemRefType>();
@@ -765,6 +805,21 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
 
     return success();
   }
+
+  LogicalResult
+  matchAndRewrite(memref::CopyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = op.source().getType().cast<BaseMemRefType>();
+    auto targetType = op.target().getType().cast<BaseMemRefType>();
+
+    if (srcType.hasRank() &&
+        srcType.cast<MemRefType>().getLayout().isIdentity() &&
+        targetType.hasRank() &&
+        targetType.cast<MemRefType>().getLayout().isIdentity())
+      return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
+
+    return lowerToMemCopyFunctionCall(op, adaptor, rewriter);
+  }
 };
 
 /// Extracts allocated, aligned pointers and offset from a ranked or unranked

diff  --git a/mlir/test/mlir-cpu-runner/copy.mlir b/mlir/test/mlir-cpu-runner/copy.mlir
index e5a471fe204dc..8581f135bc851 100644
--- a/mlir/test/mlir-cpu-runner/copy.mlir
+++ b/mlir/test/mlir-cpu-runner/copy.mlir
@@ -35,7 +35,7 @@ func @main() -> () {
   // CHECK-NEXT: [3,   4,   5]
 
   %copy_two = memref.alloc() : memref<3x2xf32>
-  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides:[1, 2]
+  %copy_two_casted = memref.reinterpret_cast %copy_two to offset: [0], sizes: [2, 3], strides: [1, 2]
     : memref<3x2xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
   memref.copy %input, %copy_two_casted : memref<2x3xf32> to memref<2x3xf32, offset: 0, strides: [1, 2]>
   %unranked_copy_two = memref.cast %copy_two : memref<3x2xf32> to memref<*xf32>
@@ -49,6 +49,13 @@ func @main() -> () {
   %copy_empty = memref.alloc() : memref<3x0x1xf32>
   // Copying an empty shape should do nothing (and should not crash).
   memref.copy %input_empty, %copy_empty : memref<3x0x1xf32> to memref<3x0x1xf32>
+
+  %input_empty_casted = memref.reinterpret_cast %input_empty to offset: [0], sizes: [0, 3, 1], strides: [3, 1, 1]
+    : memref<3x0x1xf32> to memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]>
+  %copy_empty_casted = memref.alloc() : memref<0x3x1xf32>
+  // Copying a casted empty shape should do nothing (and should not crash).
+  memref.copy %input_empty_casted, %copy_empty_casted : memref<0x3x1xf32, offset: 0, strides: [3, 1, 1]> to memref<0x3x1xf32>
+
   memref.dealloc %copy_empty : memref<3x0x1xf32>
   memref.dealloc %input_empty : memref<3x0x1xf32>
   memref.dealloc %copy_two : memref<3x2xf32>


        


More information about the Mlir-commits mailing list