[Mlir-commits] [mlir] 40f5f3d - [mlir][linalg][bufferize] Use memref.copy instead of linalg.copy
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 14 05:29:18 PST 2022
Author: Matthias Springer
Date: 2022-01-14T22:29:05+09:00
New Revision: 40f5f3d62dcd7b5e74911431c0974c762836b80f
URL: https://github.com/llvm/llvm-project/commit/40f5f3d62dcd7b5e74911431c0974c762836b80f
DIFF: https://github.com/llvm/llvm-project/commit/40f5f3d62dcd7b5e74911431c0974c762836b80f.diff
LOG: [mlir][linalg][bufferize] Use memref.copy instead of linalg.copy
Differential Revision: https://reviews.llvm.org/D117220
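For reference, the effect on bufferized IR is a switch in copy-op syntax, roughly as sketched below (operand names %src and %dst are illustrative, not taken from the patch):

  // Before: comprehensive bufferization emitted linalg.copy
  linalg.copy(%src, %dst) : memref<?xf32>, memref<?xf32>

  // After: the default memCpyFn emits memref.copy
  memref.copy %src, %dst : memref<?xf32> to memref<?xf32>

The test updates in this commit adjust the corresponding FileCheck patterns accordingly.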
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index e92f852cd6c8f..f0f1beb53ab01 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -81,11 +81,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
options->allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
Value v) {};
}
- // TODO: Change to memref::CopyOp (default memCpyFn).
- options->allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from,
- Value to) {
- b.create<linalg::CopyOp>(loc, from, to);
- };
options->allowReturnMemref = allowReturnMemref;
options->allowUnknownOps = allowUnknownOps;
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
index d30ab5ac4f9a9..de1b8321ed5eb 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
@@ -26,12 +26,12 @@ func @buffer_forwarding_conflict(
// CHECK: linalg.fill({{.*}}, %[[EXTRACT_SLICE_ALLOC]]) : f32, memref<?xf32>
%f = linalg.fill(%f0, %a) : f32, tensor<?xf32> -> tensor<?xf32>
- // CHECK: linalg.copy(%[[FUNC_ARG]], %[[ALLOC]]) : memref<?xf32>, memref<?xf32>
+ // CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref<?xf32> to memref<?xf32>
// CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32>
- // CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]]) : memref<?xf32>, memref<?xf32>
+ // CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32>
%r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor<?xf32> into tensor<?xf32>
- // CHECK: linalg.copy(%[[EXTRACT_SLICE_ALLOC]], %[[T_SUBVIEW]])
+ // CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[T_SUBVIEW]]
%r1 = tensor.insert_slice %f into %t[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index e8f484c263261..909c41cef97f2 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -141,7 +141,7 @@ func @unknown_op_not_writable(
// introducing a RaW conflict.
// CHECK: %[[dim:.*]] = tensor.dim %[[dummy]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
- // CHECK: linalg.copy(%[[dummy_memref]], %[[alloc]])
+ // CHECK: memref.copy %[[dummy_memref]], %[[alloc]]
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
%1 = vector.transfer_write %v, %0[%idx] : vector<5xf32>, tensor<?xf32>
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index 753e099354cfa..b37c3066a0d65 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -150,7 +150,7 @@ func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vec
/// Cross-op multiple uses of %A, the first vector.transfer which has interfering reads must alloc.
// CHECK: %[[ALLOC:.*]] = memref.alloc
- // CHECK: linalg.copy({{.*}}, %[[ALLOC]])
+ // CHECK: memref.copy {{.*}}, %[[ALLOC]]
// CHECK-NEXT: vector.transfer_write {{.*}}, %[[ALLOC]]
%r0 = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
@@ -185,27 +185,27 @@ func @insert_slice_fun(%A0 : tensor<?xf32> {linalg.inplaceable = false},
// CHECK: %[[REALLOC1:.*]] = memref.alloc
// Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
- // CHECK: linalg.copy(%[[A0]], %[[REALLOC3]]
+ // CHECK: memref.copy %[[A0]], %[[REALLOC3]]
// CHECK: %[[SV_A0:.*]] = memref.subview %[[REALLOC3]]
- // CHECK: linalg.copy(%[[t0]], %[[SV_A0]])
+ // CHECK: memref.copy %[[t0]], %[[SV_A0]]
%r0 = tensor.insert_slice %t0 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>
// Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
- // CHECK: linalg.copy(%[[A0]]
+ // CHECK: memref.copy %[[A0]]
// CHECK: %[[SV_A0_2:.*]] = memref.subview %[[REALLOC2]]
- // CHECK: linalg.copy(%[[t1]], %[[SV_A0_2]])
+ // CHECK: memref.copy %[[t1]], %[[SV_A0_2]]
%r1 = tensor.insert_slice %t1 into %A0[0][4][1] : tensor<4xf32> into tensor<?xf32>
// Still alloc the large tensor because %A1 is read after. Copy the tensor.extract_slice.
- // CHECK: linalg.copy(%[[A1]]
+ // CHECK: memref.copy %[[A1]]
// CHECK: %[[SV_A1:.*]] = memref.subview %[[REALLOC1]]
- // CHECK: linalg.copy(%[[t0]], %[[SV_A1]])
+ // CHECK: memref.copy %[[t0]], %[[SV_A1]]
%r2 = tensor.insert_slice %t0 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>
// Do not realloc the large tensor. Copy the tensor.extract_slice.
// CHECK-NOT: alloc
// CHECK: %[[SV_A1_2:.*]] = memref.subview %[[A1]]
- // CHECK: linalg.copy(%[[t1]], %[[SV_A1_2]])
+ // CHECK: memref.copy %[[t1]], %[[SV_A1_2]]
%r3 = tensor.insert_slice %t1 into %A1[0][4][1] : tensor<4xf32> into tensor<?xf32>
// CHECK: return %[[REALLOC3]], %[[REALLOC2]], %[[REALLOC1]] :
@@ -229,7 +229,7 @@ func @insert_slice_fun(
// CHECK-NOT: alloc
// CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
- // CHECK: linalg.copy(%[[t]], %[[SV_A]])
+ // CHECK: memref.copy %[[t]], %[[SV_A]]
%r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
/// Overwrite A inplace.
@@ -261,7 +261,7 @@ func @insert_slice_fun(
// CHECK-NOT: alloc
// CHECK: %[[SV_A:.*]] = memref.subview %[[A]]
/// Overwrite A inplace by copying into the subview.
- // CHECK: linalg.copy(%[[t]], %[[SV_A]])
+ // CHECK: memref.copy %[[t]], %[[SV_A]]
%r1 = tensor.insert_slice %t into %r0[0][4][1] : tensor<4xf32> into tensor<?xf32>
// CHECK: return
@@ -282,9 +282,9 @@ func @insert_slice_fun_not_inplace(
-> tensor<?xf32>
{
// CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) {alignment = 128 : i64} : memref<?xf32>
- // CHECK: linalg.copy(%[[A]], %[[ALLOC]]) : memref<?xf32{{.*}}, memref<?xf32>
+ // CHECK: memref.copy %[[A]], %[[ALLOC]] : memref<?xf32{{.*}} to memref<?xf32>
// CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
- // CHECK: linalg.copy(%[[t]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+ // CHECK: memref.copy %[[t]], %[[SV]] : memref<4xf32, #map> to memref<4xf32>
// CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
%r0 = tensor.insert_slice %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
@@ -310,7 +310,7 @@ func @scf_for_yield_only(%A : tensor<?xf32> {linalg.inplaceable = false},
{
// CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc
// CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
- // CHECK: linalg.copy(%[[A]], %[[ALLOC_FOR_A]])
+ // CHECK: memref.copy %[[A]], %[[ALLOC_FOR_A]]
// The first scf.for remains but just turns into dead code.
%r0 = scf.for %i = %lb to %ub step %step iter_args(%t = %A) -> (tensor<?xf32>) {
@@ -366,7 +366,7 @@ func @scf_for_with_tensor.insert_slice(
{
// CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc
// CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
- // CHECK: linalg.copy(%[[A]], %[[ALLOC_FOR_A]])
+ // CHECK: memref.copy %[[A]], %[[ALLOC_FOR_A]]
// CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
// CHECK: %[[svB:.*]] = memref.subview %[[B]][0] [4] [1]
@@ -377,11 +377,11 @@ func @scf_for_with_tensor.insert_slice(
-> (tensor<?xf32>, tensor<?xf32>)
{
// %ttA bufferizes to direct copy of %BUFFER_CAST_C into %svA
- // CHECK: linalg.copy(%[[C]], %[[svA]])
+ // CHECK: memref.copy %[[C]], %[[svA]]
%ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
// %ttB bufferizes to direct copy of %BUFFER_CAST_C into %BUFFER_CAST_B
- // CHECK: linalg.copy(%[[C]], %[[svB]])
+ // CHECK: memref.copy %[[C]], %[[svB]]
%ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
// CHECK-NOT: scf.yield
@@ -412,7 +412,7 @@ func @main() {
// CHECK: %[[alloc:.*]] = memref.alloc
// CHECK: %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
-// CHECK: linalg.copy(%[[A]], %[[alloc]])
+// CHECK: memref.copy %[[A]], %[[alloc]]
// CHECK: call @some_external_func(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
call @some_external_func(%A) : (tensor<4xi32>) -> ()
@@ -434,7 +434,7 @@ func @main() {
// CHECK: %[[alloc:.*]] = memref.alloc
// CHECK: %[[B:.*]] = memref.cast %[[alloc]] : memref<4xi32> to memref<4xi32, #[[$DYN_1D_MAP]]>
-// CHECK: linalg.copy(%[[A]], %[[alloc]])
+// CHECK: memref.copy %[[A]], %[[alloc]]
// CHECK: call @some_external_func_within_scf_execute(%[[B]]) : (memref<4xi32, #[[$DYN_1D_MAP]]>) -> ()
scf.execute_region {
call @some_external_func_within_scf_execute(%A) : (tensor<4xi32>) -> ()
@@ -465,11 +465,11 @@ func @scf_for_with_tensor_insert_slice(
-> (tensor<?xf32>, tensor<?xf32>)
{
// CHECK-NEXT: %[[SVA:.*]] = memref.subview %[[A]]
- // CHECK-NEXT: linalg.copy(%[[C]], %[[SVA]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+ // CHECK-NEXT: memref.copy %[[C]], %[[SVA]] : memref<4xf32, #[[$DYN_1D_MAP]]> to memref<4xf32, #[[$DYN_1D_MAP]]>
%ttA = tensor.insert_slice %C into %tA[%i][4][1] : tensor<4xf32> into tensor<?xf32>
// CHECK-NEXT: %[[SVB:.*]] = memref.subview %[[B]]
- // CHECK-NEXT: linalg.copy(%[[C]], %[[SVB]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+ // CHECK-NEXT: memref.copy %[[C]], %[[SVB]] : memref<4xf32, #[[$DYN_1D_MAP]]> to memref<4xf32, #[[$DYN_1D_MAP]]>
%ttB = tensor.insert_slice %C into %tB[%i][4][1] : tensor<4xf32> into tensor<?xf32>
// scf.yield is empty and is elided
@@ -500,7 +500,7 @@ func @bar(
// %r0#0 requires a copy because we have no idea what the function is doing.
// CHECK: %[[alloc:.*]] = memref.alloc
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
-// CHECK: linalg.copy(%[[B]], %[[alloc]])
+// CHECK: memref.copy %[[B]], %[[alloc]]
// CHECK-NEXT: call @some_external_func(%[[casted]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
@@ -707,7 +707,7 @@ func @tiled_loop_yield_out_of_place(
iterators["parallel"]
{
// CHECK-NOT: alloc
- // CHECK: linalg.copy(%[[B]], %[[A]])
+ // CHECK: memref.copy %[[B]], %[[A]]
linalg.yield %B : tensor<?xf32>
// CHECK: linalg.yield
// CHECK-NOT: tensor
@@ -762,9 +762,9 @@ func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] ->
// CHECK: %[[ALLOC_B:.*]] = memref.alloc
// CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
// CHECK: %[[ALLOC_A:.*]] = memref.alloc
-// CHECK: linalg.copy(%[[A]], %[[ALLOC_A]])
-// CHECK: linalg.copy(%[[B]], %[[ALLOC_B]])
-// CHECK: linalg.copy(%[[C]], %[[ALLOC_C]])
+// CHECK: memref.copy %[[A]], %[[ALLOC_A]]
+// CHECK: memref.copy %[[B]], %[[ALLOC_B]]
+// CHECK: memref.copy %[[C]], %[[ALLOC_C]]
// CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
// CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]])
call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
@@ -831,7 +831,7 @@ func @matmul(
// insert_slice is inplace but its source comes from an equivalent buffer
// that is not in place. So we must insert a copy of the small buffer into
// the bigger buffer.
- // CHECK: linalg.copy(%[[ALLOC]], %[[T]])
+ // CHECK: memref.copy %[[ALLOC]], %[[T]]
%7 = tensor.insert_slice %6 into %arg6[%arg3, %arg5] [8, 16] [1, 1] :
tensor<8x16xf32> into tensor<128x192xf32>
@@ -848,8 +848,9 @@ func @matmul(
// CHECK-LABEL: func @tensor_cast_not_in_place(
// CHECK-SAME: %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
// CHECK: %[[alloc:.*]] = memref.alloc
-// CHECK: linalg.copy(%[[A]], %[[alloc]])
-// CHECK: %[[cast:.*]] = memref.cast %[[alloc]]
+// CHECK: memref.copy %[[A]], %[[alloc]]
+// CHECK: %[[subview:.*]] = memref.subview %[[A]][{{.*}}] [4] [1] : {{.*}} to memref<4xf32
+// CHECK: memref.copy %[[alloc]], %[[subview]]
func @tensor_cast_not_in_place(
%A : tensor<?xf32> {linalg.inplaceable = true},
%B : tensor<?xf32> {linalg.inplaceable = false}, %idx: index)
@@ -1014,7 +1015,7 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
// CHECK: %[[alloc:.*]] = memref.alloc
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
- // CHECK: linalg.copy(%[[arg0]], %[[alloc]])
+ // CHECK: memref.copy %[[arg0]], %[[alloc]]
// CHECK: call @inner_func_2(%[[casted]])
%3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
scf.yield %t1 : tensor<?xf32>
@@ -1143,7 +1144,7 @@ func @linalg_op_bufferizes_out_of_place_with_input(
%t3: tensor<?x?xf32> {linalg.inplaceable = false},
%s1: index, %s2: index, %cst: f32) -> tensor<?x?xf32> {
// CHECK: %[[alloc:.*]] = memref.alloc
- // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+ // CHECK: memref.copy %[[t1]], %[[alloc]]
// CHECK: linalg.generic {{.*}} ins(%[[t1]], %[[t2]] : {{.*}}) outs(%[[alloc]] : {{.*}})
%r = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
@@ -1203,7 +1204,7 @@ func @op_is_reading_but_following_ops_are_not(
{
// Make sure that a copy is inserted here.
// CHECK: %[[ALLOC:.*]] = memref.alloc
- // CHECK: linalg.copy(%[[t0]], %[[ALLOC]])
+ // CHECK: memref.copy %[[t0]], %[[ALLOC]]
// CHECK: linalg.generic {{.*}} outs(%[[ALLOC]] : memref
%r0 =linalg.generic #trait outs (%t0 : tensor<?xf32>) {
^bb(%0: f32) :
@@ -1257,7 +1258,7 @@ func @write_to_select_op_source(
%cst = arith.constant 0.0 : f32
%idx = arith.constant 0 : index
// CHECK: %[[alloc:.*]] = memref.alloc
- // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+ // CHECK: memref.copy %[[t1]], %[[alloc]]
// CHECK: memref.store %{{.*}}, %[[alloc]]
%w = tensor.insert %cst into %t1[%idx] : tensor<?xf32>
// CHECK: %[[select:.*]] = select %{{.*}}, %[[t1]], %[[t2]]
@@ -1281,7 +1282,7 @@ func @write_after_select_read_one(
// CHECK: %[[alloc:.*]] = memref.alloc
// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
- // CHECK: linalg.copy(%[[t1]], %[[alloc]])
+ // CHECK: memref.copy %[[t1]], %[[alloc]]
// CHECK: %[[select:.*]] = select %{{.*}}, %[[casted]], %[[t2]]
%s = std.select %c, %t1, %t2 : tensor<?xf32>
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
index dc6a8ab63ef35..648a7ee2832bc 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-comprehensive-bufferize.mlir
@@ -3,7 +3,7 @@
// RUN: mlir-opt -canonicalize -convert-scf-to-std -convert-vector-to-llvm -convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
-// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext |\
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext,%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext |\
// RUN: FileCheck %s
#map0 = affine_map<(d0, d1)[s0] -> ((d1 - d0) ceildiv s0)>