[Mlir-commits] [mlir] cad0850 - [mlir][memref] Lower copy of memrefs with outer size-1 dims to intrinsic memcpy.
Oleg Shyshkov
llvmlistbot at llvm.org
Fri May 12 08:18:30 PDT 2023
Author: Oleg Shyshkov
Date: 2023-05-12T17:18:13+02:00
New Revision: cad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3
URL: https://github.com/llvm/llvm-project/commit/cad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3
DIFF: https://github.com/llvm/llvm-project/commit/cad08503b8d5ffc3834ab2f3e10f9cf44f6f0ee3.diff
LOG: [mlir][memref] Lower copy of memrefs with outer size-1 dims to intrinsic memcpy.
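With this change, more `memref.copy` ops are lowered to the efficient intrinsic `memcpy`. A memref is now treated as contiguous when its innermost dimensions are densely packed in row-major order and all remaining outer dimensions have size 1. The stride of a size-1 dimension never affects addressing, so it can be ignored. For example, the following copy now lowers to `memcpy`: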
```
memref.copy %subview, %alloc : memref<1x576xf32, strided<[704, 1]>> to memref<1x576xf32>
```
Differential Revision: https://reviews.llvm.org/D150448
Added:
Modified:
mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 1a6e5a4e8dbd0..013baef3dc07c 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1064,13 +1064,23 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
if (failed(getStridesAndOffset(type, strides, offset)))
return false;
+ // MemRef is contiguous if outer dimensions are size-1 and inner
+ // dimensions have unit strides.
int64_t runningStride = 1;
- for (unsigned i = strides.size(); i > 0; --i) {
- if (strides[i - 1] != runningStride)
- return false;
- runningStride *= type.getDimSize(i - 1);
+ int64_t curDim = strides.size() - 1;
+ // Finds all inner dimensions with unit strides.
+ while (curDim >= 0 && strides[curDim] == runningStride) {
+ runningStride *= type.getDimSize(curDim);
+ --curDim;
}
- return true;
+
+ // Check if other dimensions are size-1.
+ while (curDim >= 0 && type.getDimSize(curDim) == 1) {
+ --curDim;
+ }
+
+ // All dims are unit-strided or size-1.
+ return curDim < 0;
};
auto isContiguousMemrefType = [&](BaseMemRefType type) {
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
index 5b58198bcfd82..1c18f602a4c3b 100644
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -455,15 +455,15 @@ func.func @memref_copy_ranked() {
// -----
// CHECK-LABEL: func @memref_copy_contiguous
-func.func @memref_copy_contiguous(%in: memref<16x2xi32>, %offset: index) {
+func.func @memref_copy_contiguous(%in: memref<16x4xi32>, %offset: index) {
%buf = memref.alloc() : memref<1x2xi32>
- %sub = memref.subview %in[%offset, 0] [1, 2] [1, 1] : memref<16x2xi32> to memref<1x2xi32, strided<[2, 1], offset: ?>>
- memref.copy %sub, %buf : memref<1x2xi32, strided<[2, 1], offset: ?>> to memref<1x2xi32>
+ %sub = memref.subview %in[%offset, 0] [1, 2] [1, 1] : memref<16x4xi32> to memref<1x2xi32, strided<[4, 1], offset: ?>>
+ memref.copy %sub, %buf : memref<1x2xi32, strided<[4, 1], offset: ?>> to memref<1x2xi32>
// Skip the memref descriptor of the alloc.
// CHECK: llvm.insertvalue {{%.*}}, {{%.*}}[4, 1]
// Get the memref for the subview.
- // CHECK: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%{{.*}}, 0] [1, 2] [1, 1] : memref<16x2xi32> to memref<1x2xi32, strided<[2, 1], offset: ?>>
- // CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[SUBVIEW]] : memref<1x2xi32, strided<[2, 1], offset: ?>> to !llvm.struct<(ptr
+ // CHECK: %[[SUBVIEW:.*]] = memref.subview %{{.*}}[%{{.*}}, 0] [1, 2] [1, 1] : memref<16x4xi32> to memref<1x2xi32, strided<[4, 1], offset: ?>>
+ // CHECK: %[[DESC:.*]] = builtin.unrealized_conversion_cast %[[SUBVIEW]] : memref<1x2xi32, strided<[4, 1], offset: ?>> to !llvm.struct<(ptr
// CHECK: [[EXTRACT0:%.*]] = llvm.extractvalue %[[DESC]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[MUL1:%.*]] = llvm.mul {{.*}}, [[EXTRACT0]] : i64
// CHECK: [[EXTRACT1:%.*]] = llvm.extractvalue %[[DESC]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
More information about the Mlir-commits
mailing list