[Mlir-commits] [mlir] 13593dc - [mlir][tensor][bufferize] Fix tensor.insert_slice regression
Matthias Springer
llvmlistbot at llvm.org
Sat Nov 26 10:20:52 PST 2022
Author: Matthias Springer
Date: 2022-11-26T19:14:33+01:00
New Revision: 13593dc9dc5a8f587402bf1e5f180f2c0fc750ee
URL: https://github.com/llvm/llvm-project/commit/13593dc9dc5a8f587402bf1e5f180f2c0fc750ee
DIFF: https://github.com/llvm/llvm-project/commit/13593dc9dc5a8f587402bf1e5f180f2c0fc750ee.diff
LOG: [mlir][tensor][bufferize] Fix tensor.insert_slice regression
This reverts D132662 (apart from overall cleanups), which introduced an overly aggressive optimization for tensor.insert_slice bufferization. Instead, bufferizesToMemoryRead is improved to handle some of these cases. The remaining cases still bufferize efficiently when the canonicalizer is run before bufferization.
Differential Revision: https://reviews.llvm.org/D138745
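As an illustration (a sketch derived from the new bufferizesToMemoryRead logic in the diff below, not part of the commit itself): the destination of a tensor.insert_slice no longer counts as read when the inserted slice overwrites it entirely, i.e. all offsets are the constant 0, all strides are the constant 1, and all sizes match the static sizes of the destination type.

  // Hypothetical IR: %t is overwritten in full, so its previous contents
  // are never read and bufferization needs no copy of %t.
  %r = tensor.insert_slice %src into %t[0][10][1]
      : tensor<10xf32> into tensor<10xf32>

  // By contrast, a partial overwrite still reads the destination:
  %p = tensor.insert_slice %small into %t[2][5][1]
      : tensor<5xf32> into tensor<10xf32>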
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 6ce05c257d94f..528d83f76e050 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -615,8 +615,8 @@ static bool areEquivalentSlices(const AnalysisState &state,
return true;
}
-/// Return true if `value` is originating from the InsertSliceOp's destination
-/// or an ExtractSliceOp that matches the given InsertSliceOp.
+/// Return true if `value` is originating from an ExtractSliceOp that matches
+/// the given InsertSliceOp.
template <typename OpTy>
static bool matchesInsertDestination(const AnalysisState &state, Value value,
OpTy insertSliceOp) {
@@ -630,15 +630,6 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
if (llvm::all_of(state.findValueInReverseUseDefChain(value, matchesSlice),
matchesSlice))
return true;
-
- // Look for equivalent values.
- auto isEquivalent = [&](Value val) {
- return state.areEquivalentBufferizedValues(val, insertSliceOp.getDest());
- };
- if (llvm::all_of(state.findValueInReverseUseDefChain(
- value, isEquivalent, /*followEquivalentOnly=*/true),
- isEquivalent))
- return true;
return false;
}
@@ -727,6 +718,36 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
struct InsertSliceOpInterface
: public DstBufferizableOpInterfaceExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
+ auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+ RankedTensorType destType = insertSliceOp.getDestType();
+
+ // The source is always read.
+ if (&opOperand == &op->getOpOperand(0) /*src*/)
+ return true;
+
+ // For the destination, it depends...
+ assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest");
+
+ // Dest is not read if it is entirely overwritten. E.g.:
+ // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
+ bool allOffsetsZero =
+ llvm::all_of(insertSliceOp.getMixedOffsets(), [](OpFoldResult ofr) {
+ return isConstantIntValue(ofr, 0);
+ });
+ bool sizesMatchDestSizes = llvm::all_of(
+ llvm::enumerate(insertSliceOp.getMixedSizes()), [&](auto &it) {
+ return getConstantIntValue(it.value()) ==
+ destType.getDimSize(it.index());
+ });
+ bool allStridesOne =
+ llvm::all_of(insertSliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+ return isConstantIntValue(ofr, 1);
+ });
+ return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
+ }
+
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
const AnalysisState &state) const {
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index a1dffec50d763..405f112782a74 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -126,12 +126,18 @@ func.func @insert_slice_fun_not_inplace(
// -----
-// CHECK-LABEL: func @tensor_cast_in_place(
-// CHECK-SAME: %[[A:.*]]: memref<?xf32{{.*}}>
+// This test case could bufferize in-place with a better analysis. However, it
+// is simpler to let the canonicalizer fold away the tensor.insert_slice.
+
+// CHECK-LABEL: func @tensor_cast_not_in_place(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: memref.copy %[[A]], %[[alloc]]
// CHECK: %[[subview:.*]] = memref.subview %[[A]][{{.*}}] [4] [1] : {{.*}} to memref<4xf32
-// CHECK: memref.copy %[[A]], %[[subview]]
-func.func @tensor_cast_in_place(
- %A : tensor<?xf32> {bufferization.writable = true}, %idx: index)
+// CHECK: memref.copy %[[alloc]], %[[subview]]
+func.func @tensor_cast_not_in_place(
+ %A : tensor<?xf32> {bufferization.writable = true},
+ %B : tensor<?xf32> {bufferization.writable = false}, %idx: index)
-> (tensor<?xf32>)
{
%r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
@@ -241,13 +247,16 @@ func.func @dealloc_pad_buffer(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
// -----
+// This test case could bufferize in-place with a better analysis. However, it
+// is simpler to let the canonicalizer fold away the tensor.insert_slice.
+
// CHECK-LABEL: func @insert_equivalent_tensor
func.func @insert_equivalent_tensor(%t: tensor<10xf32>) -> tensor<10xf32> {
- // CHECK-NOT: memref.alloc
+ // CHECK: memref.alloc
%cst = arith.constant 4.200000e+01 : f32
// CHECK: linalg.fill
%0 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
- // CHECK-NOT: memref.copy
+ // CHECK: memref.copy
%1 = tensor.insert_slice %0 into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
return %1 : tensor<10xf32>
}
@@ -279,3 +288,45 @@ func.func @pad_memory_space(%t: tensor<?xf32>, %h1: index, %f: f32, %pos: index)
// CHECK-DAG: memref.dealloc %[[padded_alloc]]
return %2 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_regression(
+// CHECK-SAME: %[[t:.*]]: memref<10xf32,{{.*}}>, %[[b:.*]]: memref<5xf32
+func.func @insert_slice_regression(%t: tensor<10xf32>, %b: tensor<5xf32>) -> tensor<10xf32> {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10xf32>
+ // CHECK: linalg.fill {{.*}} outs(%[[alloc]] : memref<10xf32>)
+ %1 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
+
+ // Read %1 so that it does not DCE away.
+ %vec = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<10xf32>
+ vector.print %vec : vector<10xf32>
+
+  // Write back a different value (not %1).
+ // CHECK: %[[subview:.*]] = memref.subview %[[t]][0] [5] [1]
+ // CHECK: memref.copy %[[b]], %[[subview]]
+ %2 = tensor.insert_slice %b into %t[0][5][1] : tensor<5xf32> into tensor<10xf32>
+ return %2 : tensor<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_slice_full_overwrite(
+// CHECK-SAME: %[[t:.*]]: memref<10xf32,{{.*}}>, %[[b:.*]]: memref<10xf32,{{.*}}>
+func.func @insert_slice_full_overwrite(%t: tensor<10xf32>, %b: tensor<10xf32>) -> tensor<10xf32> {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ // CHECK: linalg.fill {{.*}} outs(%[[t]] : memref<10xf32,{{.*}}>)
+ %1 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
+
+ // Read %1 so that it does not DCE away.
+ %vec = vector.transfer_read %1[%c0], %cst : tensor<10xf32>, vector<10xf32>
+ vector.print %vec : vector<10xf32>
+
+  // Write back a different value (not %1).
+ // CHECK: memref.copy %[[b]], %[[t]]
+ %2 = tensor.insert_slice %b into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
+ return %2 : tensor<10xf32>
+}
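For the test cases above that now bufferize out-of-place (e.g. @tensor_cast_not_in_place and @insert_equivalent_tensor), the extra alloc/copy can be avoided in practice by folding the tensor.insert_slice away before bufferization, for example (pass names assumed from upstream mlir-opt at the time of this commit):

  $ mlir-opt input.mlir -canonicalize -one-shot-bufferize

  // Before -canonicalize (from @insert_equivalent_tensor above):
  %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
  %1 = tensor.insert_slice %0 into %t[0][10][1]
      : tensor<10xf32> into tensor<10xf32>

  // After -canonicalize: %1 folds to %0 because the source already covers
  // the entire destination and the types match, so One-Shot Bufferize
  // never sees the insert_slice.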