[Mlir-commits] [mlir] d2608ad - [mlir][bufferize] Do not insert useless casts for newly allocated buffers
Matthias Springer
llvmlistbot at llvm.org
Fri Apr 8 02:13:34 PDT 2022
Author: Matthias Springer
Date: 2022-04-08T18:12:02+09:00
New Revision: d2608adf490c10afc71d57141d61a9df5464fd82
URL: https://github.com/llvm/llvm-project/commit/d2608adf490c10afc71d57141d61a9df5464fd82
DIFF: https://github.com/llvm/llvm-project/commit/d2608adf490c10afc71d57141d61a9df5464fd82.diff
LOG: [mlir][bufferize] Do not insert useless casts for newly allocated buffers
Differential Revision: https://reviews.llvm.org/D123369
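As a rough illustration of the change (hypothetical IR; the value names and the layout map are assumed, not taken verbatim from the patch), bufferization of a newly allocated out-of-place buffer previously produced an allocation followed by a cast back to the requested layout-mapped type:

    %alloc = memref.alloc(%d) : memref<?xf32>
    %0 = memref.cast %alloc
        : memref<?xf32> to memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>

With this change, createAlloc returns the allocation with its exact allocated type and no longer inserts such casts; callers that need a different (cast-compatible) type insert the cast themselves, as the scf.for bufferization below now does.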
Added:
Modified:
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8c68db885de63..091462f1ed73a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -466,7 +466,6 @@ FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
// Compute allocation memref type.
assert(shapedValue.getType().isa<ShapedType>());
- MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
SmallVector<Value> dynShape;
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
@@ -485,17 +484,7 @@ FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
}
// Create the buffer allocation.
- Value alloc =
- createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
-
-  // Insert a cast if a different type was requested.
- if (memRefType && memRefType != allocMemRefType) {
- assert(memref::CastOp::areCastCompatible(allocMemRefType, memRefType) &&
- "createAlloc: cast incompatible");
- alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
- }
-
- return alloc;
+ return createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
}
/// Create a memory copy between two memref buffers.
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index fde24fd8b7ab4..7c6bbcd414d71 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -411,7 +411,22 @@ struct ForOpInterface
*yieldedAlloc, state.getOptions());
(void)copyStatus;
assert(succeeded(copyStatus) && "could not create memcpy");
- return *yieldedAlloc;
+
+ if (yieldedVal.getType() == yieldedAlloc->getType())
+ return *yieldedAlloc;
+
+ // The iter_arg memref type has a layout map. Cast the new buffer to
+ // the same type.
+ // TODO: In case the iter_arg has a layout map that is not the fully
+ // dynamic one, we cannot cast the new buffer. In that case, the
+ // iter_arg must be changed to the fully dynamic layout map. (And then
+ // the new buffer can be casted.)
+ assert(memref::CastOp::areCastCompatible(yieldedAlloc->getType(),
+ yieldedVal.getType()) &&
+ "scf.for op bufferization: cast incompatible");
+ Value casted = rewriter.create<memref::CastOp>(
+ val.getLoc(), yieldedVal.getType(), *yieldedAlloc);
+ return casted;
});
yieldOp.getResultsMutable().assign(yieldValues);
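A rough sketch of the scf.for case handled above (hypothetical IR; names such as %m, %sz and #layout are assumed): when the iter_arg's memref type carries a layout map but the newly allocated yielded buffer does not, the yielded buffer is now cast to the iter_arg type before being yielded:

    #layout = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
    %r = scf.for %iv = %lb to %ub step %step
        iter_args(%arg = %m) -> (memref<?xf32, #layout>) {
      %alloc = memref.alloc(%sz) : memref<?xf32>
      memref.copy %arg, %alloc : memref<?xf32, #layout> to memref<?xf32>
      %cast = memref.cast %alloc : memref<?xf32> to memref<?xf32, #layout>
      scf.yield %cast : memref<?xf32, #layout>
    }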
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
index 2fdb0a45e9d34..efc3038820ace 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
@@ -193,12 +193,11 @@ func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
// CHECK-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
%c0 = arith.constant 0 : index
// CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
- // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
- // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
+ // CHECK-TENSOR: %[[casted_alloc:.*]] = bufferization.to_tensor %[[alloc]]
// CHECK-TENSOR: memref.copy %[[t1_memref]], %[[alloc]]
// CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
%0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
- // CHECK-TENSOR: return %[[casted_tensor]]
+ // CHECK-TENSOR: return %[[casted_alloc]]
return %0 : tensor<?xf32>
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index ac2249da4282c..3d8d09460484a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -29,10 +29,9 @@ func @return_tensor(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
// CHECK: %[[dim:.*]] = tensor.dim %[[A]]
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
- // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
// CHECK: memref.copy %[[A_memref]], %[[alloc]]
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
- // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
+ // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
// CHECK: memref.dealloc %[[alloc]]
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index b8f9dcf0149b8..7ed7bd13321b1 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -308,16 +308,16 @@ func @insert_slice_fun_not_inplace(
// CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-// CHECK-LABEL: func @scf_for_yield_only
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
+// CHECK-LABEL: func @scf_for_yield_only(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>,
// CHECK-SAME: %[[t:[a-zA-Z0-9]*]]: memref<?xf32, #[[$map_1d_dyn]]>
+// CHECK-SAME: ) -> memref<?xf32> {
func @scf_for_yield_only(%A : tensor<?xf32> {linalg.inplaceable = false},
%B : tensor<?xf32> {linalg.inplaceable = true},
%lb : index, %ub : index, %step : index)
-> (tensor<?xf32>, tensor<?xf32>)
{
// CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc
- // CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
// CHECK: memref.copy %[[A]], %[[ALLOC_FOR_A]]
// The first scf.for remains but just turns into dead code.
@@ -330,7 +330,7 @@ func @scf_for_yield_only(%A : tensor<?xf32> {linalg.inplaceable = false},
scf.yield %t : tensor<?xf32>
}
- // CHECK: return %[[CASTED]] : memref<?xf32, #[[$map_1d_dyn]]>
+ // CHECK: return %[[ALLOC_FOR_A]] : memref<?xf32>
// CHECK-NOT: dealloc
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}
@@ -373,7 +373,6 @@ func @scf_for_with_tensor.insert_slice(
-> (tensor<?xf32>, tensor<?xf32>)
{
// CHECK: %[[ALLOC_FOR_A:.*]] = memref.alloc
- // CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
// CHECK: memref.copy %[[A]], %[[ALLOC_FOR_A]]
// CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
@@ -396,7 +395,7 @@ func @scf_for_with_tensor.insert_slice(
scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
}
- // CHECK: return %[[CASTED]] : memref<?xf32, #[[$map_1d_dyn]]>
+ // CHECK: return %[[ALLOC_FOR_A]] : memref<?xf32>
return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
}
@@ -418,7 +417,6 @@ func @execute_region_with_conflict(%t1 : tensor<?xf32> {linalg.inplaceable = "tr
// memref.store is left over.
// CHECK: %[[alloc:.*]] = memref.alloc
- // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
// CHECK: memref.copy %[[m1]], %[[alloc]]
// CHECK: memref.store %{{.*}}, %[[alloc]][%{{.*}}]
%0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
@@ -426,6 +424,7 @@ func @execute_region_with_conflict(%t1 : tensor<?xf32> {linalg.inplaceable = "tr
scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
}
+ // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
// CHECK: %[[load:.*]] = memref.load %[[m1]]
%3 = tensor.extract %t1[%idx] : tensor<?xf32>
@@ -783,8 +782,8 @@ func @write_after_select_read_one(
%idx = arith.constant 0 : index
// CHECK: %[[alloc:.*]] = memref.alloc
- // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
- // CHECK: memref.copy %[[t1]], %[[alloc]]
+ // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK-DAG: memref.copy %[[t1]], %[[alloc]]
// CHECK: %[[select:.*]] = arith.select %{{.*}}, %[[casted]], %[[t2]]
%s = arith.select %c, %t1, %t2 : tensor<?xf32>
@@ -859,13 +858,11 @@ func @scf_execute_region_yield_non_equivalent(%i: index, %j: index) -> f32 {
// CHECK-LABEL: func @scf_for_yield_non_equivalent(
// CHECK-SAME: %[[t:.*]]: memref<?xf32
// CHECK: %[[alloc:.*]] = memref.alloc(%{{.*}})
-// CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
-// CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[casted]])
+// CHECK: %[[for:.*]] = scf.for {{.*}} iter_args(%[[iter:.*]] = %[[alloc]])
// CHECK: memref.dealloc %[[iter]]
// CHECK: %[[alloc2:.*]] = memref.alloc(%{{.*}})
// CHECK: memref.copy %[[t]], %[[alloc2]]
-// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]]
-// CHECK: scf.yield %[[casted2]]
+// CHECK: scf.yield %[[alloc2]]
// CHECK: return %[[for]]
func @scf_for_yield_non_equivalent(
%t: tensor<?xf32>, %lb : index, %ub : index, %step : index) -> tensor<?xf32> {
diff --git a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
index 72661713a5b44..24fe44ca58a34 100644
--- a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
@@ -358,7 +358,7 @@ func @bar(
// %r0#0 requires a copy because we have no idea what the function is doing.
// CHECK-DAG: %[[alloc:.*]] = memref.alloc
// CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]]
-// CHECK: memref.copy %[[B]], %[[alloc]]
+// CHECK-DAG: memref.copy %[[B]], %[[alloc]]
// CHECK-NEXT: call @some_external_func(%[[casted]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
@@ -475,15 +475,15 @@ func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] ->
// conflict. However, inside `entry`, the writes do cause a conflict because
// %A, %B and %C are not inplaceable. This test case shows that this kind of
// conflict detection has a "transitive" nature.
-// CHECK: %[[ALLOC_C:.*]] = memref.alloc
-// CHECK: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
-// CHECK: %[[ALLOC_B:.*]] = memref.alloc
-// CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
-// CHECK: %[[ALLOC_A:.*]] = memref.alloc
-// CHECK: memref.copy %[[A]], %[[ALLOC_A]]
-// CHECK: memref.copy %[[B]], %[[ALLOC_B]]
-// CHECK: memref.copy %[[C]], %[[ALLOC_C]]
-// CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+// CHECK-DAG: %[[ALLOC_C:.*]] = memref.alloc
+// CHECK-DAG: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
+// CHECK-DAG: %[[ALLOC_B:.*]] = memref.alloc
+// CHECK-DAG: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
+// CHECK-DAG: %[[ALLOC_A:.*]] = memref.alloc
+// CHECK-DAG: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+// CHECK-DAG: memref.copy %[[A]], %[[ALLOC_A]]
+// CHECK-DAG: memref.copy %[[B]], %[[ALLOC_B]]
+// CHECK-DAG: memref.copy %[[C]], %[[ALLOC_C]]
// CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]])
call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
return
@@ -539,8 +539,8 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
// CHECK: scf.for {{.*}} {
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
// CHECK: %[[alloc:.*]] = memref.alloc
- // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
- // CHECK: memref.copy %[[arg0]], %[[alloc]]
+ // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]]
+ // CHECK-DAG: memref.copy %[[arg0]], %[[alloc]]
// CHECK: call @inner_func_2(%[[casted]])
// CHECK: memref.dealloc %[[alloc]]
// CHECK-NOT: scf.yield