[Mlir-commits] [mlir] cad0850 - [mlir][memref] Lower copy of memrefs with outer size-1 dims to intrinsic memcpy.

Oleg Shyshkov llvmlistbot at llvm.org
Fri May 12 08:18:30 PDT 2023


Author: Oleg Shyshkov
Date: 2023-05-12T17:18:13+02:00
New Revision: cad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3

URL: https://github.com/llvm/llvm-project/commit/cad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3
DIFF: https://github.com/llvm/llvm-project/commit/cad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3.diff

LOG: [mlir][memref] Lower copy of memrefs with outer size-1 dims to intrinsic memcpy.

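With this change, more `memref.copy` ops are lowered to the efficient `memcpy` intrinsic. For example, the following copy, whose source has an outer size-1 dimension with a non-contiguous stride, is now lowered to `memcpy`: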

```
memref.copy %subview, %alloc : memref<1x576xf32, strided<[704, 1]>> to memref<1x576xf32>
```

Differential Revision: https://reviews.llvm.org/D150448

Added: 
    

Modified: 
    mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
    mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 1a6e5a4e8dbd0..013baef3dc07c 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1064,13 +1064,23 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
       if (failed(getStridesAndOffset(type, strides, offset)))
         return false;
 
+      // MemRef is contiguous if outer dimensions are size-1 and inner
+      // dimensions have unit strides.
       int64_t runningStride = 1;
-      for (unsigned i = strides.size(); i > 0; --i) {
-        if (strides[i - 1] != runningStride)
-          return false;
-        runningStride *= type.getDimSize(i - 1);
+      int64_t curDim = strides.size() - 1;
+      // Finds all inner dimensions with unit strides.
+      while (curDim >= 0 && strides[curDim] == runningStride) {
+        runningStride *= type.getDimSize(curDim);
+        --curDim;
       }
-      return true;
+
+      // Check if other dimensions are size-1.
+      while (curDim >= 0 && type.getDimSize(curDim) == 1) {
+        --curDim;
+      }
+
+      // All dims are unit-strided or size-1.
+      return curDim < 0;
     };
 
     auto isContiguousMemrefType = [&](BaseMemRefType type) {

diff  --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 5b58198bcfd82..1c18f602a4c3b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -455,15 +455,15 @@ func.func @memref_copy_ranked() {
 // -----
 
 // CHECK-LABEL: func @memref_copy_contiguous
-func.func @memref_copy_contiguous(%in: memref<16x2xi32>, %offset: index) {
+func.func @memref_copy_contiguous(%in: memref<16x4xi32>, %offset: index) {
   %buf = memref.alloc() : memref<1x2xi32>
-  %sub = memref.subview %in[%offset, 0] [1, 2] [1, 1] : memref<16x2xi32> to memref<1x2xi32, strided<[2, 1], offset: ?>>
-  memref.copy %sub, %buf : memref<1x2xi32, strided<[2, 1], offset: ?>> to memref<1x2xi32>
+  %sub = memref.subview %in[%offset, 0] [1, 2] [1, 1] : memref<16x4xi32> to memref<1x2xi32, strided<[4, 1], offset: ?>>
+  memref.copy %sub, %buf : memref<1x2xi32, strided<[4, 1], offset: ?>> to memref<1x2xi32>
   // Skip the memref descriptor of the alloc.
   // CHECK: llvm.insertvalue {{%.*}}, {{%.*}}[4, 1]
   // Get the memref for the subview.
-  // CHECK: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%{{.*}}, 0] [1, 2] [1, 1] : memref<16x2xi32> to memref<1x2xi32, strided<[2, 1], offset: ?>>
-  // CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[SUBVIEW]] : memref<1x2xi32, strided<[2, 1], offset: ?>> to !llvm.struct<(ptr
+  // CHECK: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%{{.*}}, 0] [1, 2] [1, 1] : memref<16x4xi32> to memref<1x2xi32, strided<[4, 1], offset: ?>>
+  // CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[SUBVIEW]] : memref<1x2xi32, strided<[4, 1], offset: ?>> to !llvm.struct<(ptr
   // CHECK: [[EXTRACT0:%.*]] = llvm.extractvalue %[[DESC]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK: [[MUL1:%.*]] = llvm.mul {{.*}}, [[EXTRACT0]] : i64
   // CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue %[[DESC]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>


        


More information about the Mlir-commits mailing list