[Mlir-commits] [mlir] 6cdd34b - [mlir][tensor][bufferize] Bufferize inserts into equivalent tensors in-place
Matthias Springer
llvmlistbot at llvm.org
Wed Oct 5 23:14:52 PDT 2022
Author: Matthias Springer
Date: 2022-10-06T15:06:33+09:00
New Revision: 6cdd34b9739a41caeda14f63bceab6fec7fd0ae5
URL: https://github.com/llvm/llvm-project/commit/6cdd34b9739a41caeda14f63bceab6fec7fd0ae5
DIFF: https://github.com/llvm/llvm-project/commit/6cdd34b9739a41caeda14f63bceab6fec7fd0ae5.diff
LOG: [mlir][tensor][bufferize] Bufferize inserts into equivalent tensors in-place
Inserting a tensor into an equivalent tensor is a no-op after bufferization. No alloc is needed.
Differential Revision: https://reviews.llvm.org/D132662
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index cd2be2925c68..2c9dd66c45c4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -390,8 +390,12 @@ class AnalysisState {
/// In the above example, Values with a star satisfy the condition. When
/// starting the traversal from Value 1, the resulting SetVector is:
/// { 2, 7, 8, 5 }
- SetVector<Value> findValueInReverseUseDefChain(
- Value value, llvm::function_ref<bool(Value)> condition) const;
+ ///
+ /// If `followEquivalentOnly` is set, only equivalent OpOperands are selected.
+ SetVector<Value>
+ findValueInReverseUseDefChain(Value value,
+ llvm::function_ref<bool(Value)> condition,
+ bool followEquivalentOnly = false) const;
/// Find the Values of the last preceding write of a given Value.
///
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 0d3495100fda..3c135068af1a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -398,7 +398,8 @@ bool AnalysisState::isValueRead(Value value) const {
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
- Value value, llvm::function_ref<bool(Value)> condition) const {
+ Value value, llvm::function_ref<bool(Value)> condition,
+ bool followEquivalentOnly) const {
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
@@ -410,8 +411,19 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
}
OpResult opResult = value.cast<OpResult>();
+ BufferizableOpInterface bufferizableOp =
+ options.dynCastBufferizableOp(opResult.getDefiningOp());
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
- if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
+
+ // Stop iterating in either one of these cases:
+ // * The current op is not bufferizable or excluded in the filter.
+ // * There are no OpOperands to follow.
+ // * There is an OpOperand, but it is not an equivalent tensor (only if
+ // `followEquivalentOnly` is set).
+ if (!bufferizableOp || opOperands.empty() ||
+ (followEquivalentOnly &&
+ bufferizableOp.bufferRelation(opResult, *this) !=
+ BufferRelation::Equivalent)) {
result.insert(value);
continue;
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 8a92129675ca..16e84a42064f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -611,9 +611,9 @@ struct InsertOpInterface
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
template <typename OpTy>
-static bool areEquivalentExtractSliceOps(const AnalysisState &state,
- ExtractSliceOp extractSliceOp,
- OpTy insertSliceOp) {
+static bool areEquivalentSlices(const AnalysisState &state,
+ ExtractSliceOp extractSliceOp,
+ OpTy insertSliceOp) {
if (!extractSliceOp || !insertSliceOp)
return false;
if (extractSliceOp != insertSliceOp &&
@@ -626,20 +626,31 @@ static bool areEquivalentExtractSliceOps(const AnalysisState &state,
return true;
}
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
+/// Return true if `value` is originating from the InsertSliceOp's destination
+/// or an ExtractSliceOp that matches the given InsertSliceOp.
template <typename OpTy>
-static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
- OpTy insertSliceOp) {
- auto condition = [&](Value val) {
+static bool matchesInsertDestination(const AnalysisState &state, Value value,
+ OpTy insertSliceOp) {
+ // Look for matching slices.
+ auto matchesSlice = [&](Value val) {
if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
- if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
+ if (areEquivalentSlices(state, extractSliceOp, insertSliceOp))
return true;
return false;
};
+ if (llvm::all_of(state.findValueInReverseUseDefChain(value, matchesSlice),
+ matchesSlice))
+ return true;
- return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
- condition);
+ // Look for equivalent values.
+ auto isEquivalent = [&](Value val) {
+ return state.areEquivalentBufferizedValues(val, insertSliceOp.getDest());
+ };
+ if (llvm::all_of(state.findValueInReverseUseDefChain(
+ value, isEquivalent, /*followEquivalentOnly=*/true),
+ isEquivalent))
+ return true;
+ return false;
}
template <typename OpTy>
@@ -661,8 +672,8 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
- insertSliceOp))
+ matchesInsertDestination(state, uConflictingWrite->get(),
+ insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@@ -679,7 +690,7 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+ matchesInsertDestination(state, uRead->get(), insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@@ -712,8 +723,8 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
state.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.getSource()) &&
- hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
- insertSliceOp))
+ matchesInsertDestination(state, insertSliceOp.getSource(),
+ insertSliceOp))
return true;
return false;
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 3fc9f1ce1fc9..e80027013d1c 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -126,15 +126,12 @@ func.func @insert_slice_fun_not_inplace(
// -----
-// CHECK-LABEL: func @tensor_cast_not_in_place(
-// CHECK-SAME: %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
-// CHECK: %[[alloc:.*]] = memref.alloc
-// CHECK: memref.copy %[[A]], %[[alloc]]
+// CHECK-LABEL: func @tensor_cast_in_place(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32{{.*}}>
// CHECK: %[[subview:.*]] = memref.subview %[[A]][{{.*}}] [4] [1] : {{.*}} to memref<4xf32
-// CHECK: memref.copy %[[alloc]], %[[subview]]
-func.func @tensor_cast_not_in_place(
- %A : tensor<?xf32> {bufferization.writable = true},
- %B : tensor<?xf32> {bufferization.writable = false}, %idx: index)
+// CHECK: memref.copy %[[A]], %[[subview]]
+func.func @tensor_cast_in_place(
+ %A : tensor<?xf32> {bufferization.writable = true}, %idx: index)
-> (tensor<?xf32>)
{
%r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
@@ -243,3 +240,16 @@ func.func @dealloc_pad_buffer(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
%r = tensor.extract %0[%idx, %idx] : tensor<?x?xindex>
return %r : index
}
+
+// -----
+
+// CHECK-LABEL: func @insert_equivalent_tensor
+func.func @insert_equivalent_tensor(%t: tensor<10xf32>) -> tensor<10xf32> {
+ // CHECK-NOT: memref.alloc
+ %cst = arith.constant 4.200000e+01 : f32
+ // CHECK: linalg.fill
+ %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
+ // CHECK-NOT: memref.copy
+ %1 = tensor.insert_slice %0 into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
+ return %1 : tensor<10xf32>
+}
More information about the Mlir-commits mailing list