[Mlir-commits] [mlir] 27cd2a6 - [mlir][MemRef] Lower memref.copy with an offset to memcpy

Benjamin Kramer llvmlistbot at llvm.org
Wed Feb 16 08:19:05 PST 2022


Author: Benjamin Kramer
Date: 2022-02-16T17:18:31+01:00
New Revision: 27cd2a6284b8c59f5dbd9086cf80db3b7b7047b1

URL: https://github.com/llvm/llvm-project/commit/27cd2a6284b8c59f5dbd9086cf80db3b7b7047b1
DIFF: https://github.com/llvm/llvm-project/commit/27cd2a6284b8c59f5dbd9086cf80db3b7b7047b1.diff

LOG: [mlir][MemRef] Lower memref.copy with an offset to memcpy

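memcpy can handle memrefs with a non-zero offset as long as they are contiguous: the descriptor's offset is applied to the aligned pointer with a GEP before the memcpy is emitted.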

Differential Revision: https://reviews.llvm.org/D119938

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index a8910c2667a38..56413c4155903 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -857,12 +857,18 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
         rewriter.create<LLVM::MulOp>(loc, numElements, sizeInBytes);
 
     Value srcBasePtr = srcDesc.alignedPtr(rewriter, loc);
+    Value srcOffset = srcDesc.offset(rewriter, loc);
+    Value srcPtr = rewriter.create<LLVM::GEPOp>(loc, srcBasePtr.getType(),
+                                                srcBasePtr, srcOffset);
     MemRefDescriptor targetDesc(adaptor.target());
     Value targetBasePtr = targetDesc.alignedPtr(rewriter, loc);
+    Value targetOffset = targetDesc.offset(rewriter, loc);
+    Value targetPtr = rewriter.create<LLVM::GEPOp>(loc, targetBasePtr.getType(),
+                                                   targetBasePtr, targetOffset);
     Value isVolatile = rewriter.create<LLVM::ConstantOp>(
         loc, typeConverter->convertType(rewriter.getI1Type()),
         rewriter.getBoolAttr(false));
-    rewriter.create<LLVM::MemcpyOp>(loc, targetBasePtr, srcBasePtr, totalSize,
+    rewriter.create<LLVM::MemcpyOp>(loc, targetPtr, srcPtr, totalSize,
                                     isVolatile);
     rewriter.eraseOp(op);
 
@@ -933,10 +939,18 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto srcType = op.source().getType().cast<BaseMemRefType>();
     auto targetType = op.target().getType().cast<BaseMemRefType>();
 
-    if (srcType.hasRank() &&
-        srcType.cast<MemRefType>().getLayout().isIdentity() &&
-        targetType.hasRank() &&
-        targetType.cast<MemRefType>().getLayout().isIdentity())
+    auto isContiguousMemrefType = [](BaseMemRefType type) {
+      auto memrefType = type.dyn_cast<mlir::MemRefType>();
+      // We can use memcpy for memrefs if they have an identity layout or are
+      // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
+      // special case handled by memrefCopy.
+      return memrefType &&
+             (memrefType.getLayout().isIdentity() ||
+              (memrefType.hasStaticShape() && memrefType.getNumElements() > 0 &&
+               isStaticShapeAndContiguousRowMajor(memrefType)));
+    };
+
+    if (isContiguousMemrefType(srcType) && isContiguousMemrefType(targetType))
       return lowerToMemCopyIntrinsic(op, adaptor, rewriter);
 
     return lowerToMemCopyFunctionCall(op, adaptor, rewriter);

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 2fc6905117623..85c6ca17747ba 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -933,10 +933,55 @@ func @memref_copy_ranked() {
   // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]][[[ONE2]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
   // CHECK: [[PTRTOINT:%.*]] = llvm.ptrtoint [[GEP]] : !llvm.ptr<f32> to i64
   // CHECK: [[SIZE:%.*]] = llvm.mul [[MUL]], [[PTRTOINT]] : i64
-  // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: [[EXTRACT2:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[EXTRACT1P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[EXTRACT1O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[GEP1:%.*]] = llvm.getelementptr [[EXTRACT1P]][[[EXTRACT1O]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+  // CHECK: [[EXTRACT2P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[EXTRACT2O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: [[GEP2:%.*]] = llvm.getelementptr [[EXTRACT2P]][[[EXTRACT2O]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
   // CHECK: [[VOLATILE:%.*]] = llvm.mlir.constant(false) : i1
-  // CHECK: "llvm.intr.memcpy"([[EXTRACT2]], [[EXTRACT1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()
+  // CHECK: "llvm.intr.memcpy"([[GEP2]], [[GEP1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<f32>, !llvm.ptr<f32>, i64, i1) -> ()
+  return
+}
+
+
+// -----
+
+// CHECK-LABEL: func @memref_copy_contiguous
+#map = affine_map<(d0, d1)[s0] -> (d0 * 2 + s0 + d1)>
+func @memref_copy_contiguous(%in: memref<16x2xi32>, %offset: index) {
+  %buf = memref.alloc() : memref<1x2xi32>
+  %sub = memref.subview %in[%offset, 0] [1, 2] [1, 1] : memref<16x2xi32> to memref<1x2xi32, #map>
+  memref.copy %sub, %buf : memref<1x2xi32, #map> to memref<1x2xi32>
+  // CHECK: [[EXTRACT0:%.*]] = llvm.extractvalue {{%.*}}[3, 0] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: [[MUL1:%.*]] = llvm.mul {{.*}}, [[EXTRACT0]] : i64
+  // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue {{%.*}}[3, 1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: [[MUL2:%.*]] = llvm.mul [[MUL1]], [[EXTRACT1]] : i64
+  // CHECK: [[NULL:%.*]] = llvm.mlir.null : !llvm.ptr<i32>
+  // CHECK: [[ONE2:%.*]] = llvm.mlir.constant(1 : index) : i64
+  // CHECK: [[GEP:%.*]] = llvm.getelementptr [[NULL]][[[ONE2]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+  // CHECK: [[PTRTOINT:%.*]] = llvm.ptrtoint [[GEP]] : !llvm.ptr<i32> to i64
+  // CHECK: [[SIZE:%.*]] = llvm.mul [[MUL2]], [[PTRTOINT]] : i64
+  // CHECK: [[EXTRACT1P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: [[EXTRACT1O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: [[GEP1:%.*]] = llvm.getelementptr [[EXTRACT1P]][[[EXTRACT1O]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+  // CHECK: [[EXTRACT2P:%.*]] = llvm.extractvalue {{%.*}}[1] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: [[EXTRACT2O:%.*]] = llvm.extractvalue {{%.*}}[2] : !llvm.struct<(ptr<i32>, ptr<i32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: [[GEP2:%.*]] = llvm.getelementptr [[EXTRACT2P]][[[EXTRACT2O]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
+  // CHECK: [[VOLATILE:%.*]] = llvm.mlir.constant(false) : i1
+  // CHECK: "llvm.intr.memcpy"([[GEP2]], [[GEP1]], [[SIZE]], [[VOLATILE]]) : (!llvm.ptr<i32>, !llvm.ptr<i32>, i64, i1) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @memref_copy_noncontiguous
+#map = affine_map<(d0, d1)[s0] -> (d0 * 2 + s0 + d1)>
+func @memref_copy_noncontiguous(%in: memref<16x2xi32>, %offset: index) {
+  %buf = memref.alloc() : memref<2x1xi32>
+  %sub = memref.subview %in[%offset, 0] [2, 1] [1, 1] : memref<16x2xi32> to memref<2x1xi32, #map>
+  memref.copy %sub, %buf : memref<2x1xi32, #map> to memref<2x1xi32>
+  // CHECK: llvm.call @memrefCopy
   return
 }
 


        


More information about the Mlir-commits mailing list