[Mlir-commits] [mlir] [mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand (PR #121304)
Amir Bishara
llvmlistbot at llvm.org
Sun Dec 29 14:06:44 PST 2024
https://github.com/amirBish created https://github.com/llvm/llvm-project/pull/121304
Edit the `findValueInReverseUseDefChain` method to accept an `OpOperand` instead of a `Value`. This ensures that the populated `visitedOpOperands` argument is fully accurate: it now also contains the `OpOperand` from which the reverse use-def chain traversal starts.
>From c883d9e10933b744aa33e03b038283a2bb9667ba Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Mon, 30 Dec 2024 00:01:20 +0200
Subject: [PATCH] [mlir][bufferization]-Refactor findValueInReverseUseDefChain
to accept opOperand
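Edit the `findValueInReverseUseDefChain` method to accept an `OpOperand`
instead of a `Value`. This ensures that the populated `visitedOpOperands`
argument is fully accurate: it now also contains the `OpOperand` from
which the reverse use-def chain traversal starts. `findDefinitions`,
`findDefinitionsCached` and all callers are updated accordingly, and empty
tensor elimination now also handles a `tensor.empty` that is consumed
directly as the source of a subset insertion op.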
---
.../IR/BufferizableOpInterface.h | 6 +--
.../Transforms/OneShotAnalysis.h | 6 +--
.../IR/BufferizableOpInterface.cpp | 17 +++++----
.../Transforms/EmptyTensorElimination.cpp | 8 ++--
.../Transforms/OneShotAnalysis.cpp | 38 +++++++++++--------
.../Transforms/ConvertToDestinationStyle.cpp | 2 +-
.../Transforms/EliminateEmptyTensors.cpp | 2 +-
...ot-bufferize-empty-tensor-elimination.mlir | 11 ++++++
8 files changed, 54 insertions(+), 36 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 983f7a29cb2206..d1a102e2a6e4e8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -456,7 +456,7 @@ class AnalysisState {
/// read by themselves (e.g., ExtractSliceOp).
bool isValueRead(Value value) const;
- /// Starting from `value`, follow the use-def chain in reverse, always
+ /// Starting from `opOperand`, follow the use-def chain in reverse, always
/// selecting the aliasing OpOperands. Find and return Values for which
/// `condition` evaluates to true. OpOperands of such matching Values are not
/// traversed any further, the visited aliasing opOperands will be preserved
@@ -484,7 +484,7 @@ class AnalysisState {
/// Additional stopping conditions for the traversal can be specified in
/// `config`.
SetVector<Value> findValueInReverseUseDefChain(
- Value value, llvm::function_ref<bool(Value)> condition,
+ OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
TraversalConfig config = TraversalConfig(),
llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
@@ -520,7 +520,7 @@ class AnalysisState {
///
/// Note: OpResults of unknown ops are handled conservatively and assumed to
/// be definitions.
- SetVector<Value> findDefinitions(Value value) const;
+ SetVector<Value> findDefinitions(OpOperand *opOperand) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
virtual bool isInPlace(OpOperand &opOperand) const;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index d50a3042aeeacf..da3094a6d6f546 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -127,9 +127,9 @@ class OneShotAnalysisState : public AnalysisState {
/// Return true if the buffer of the given tensor value is writable.
bool isWritable(Value value) const;
- /// Find the definitions of the given tensor value or retrieve them from the
- /// cache.
- const SetVector<Value> &findDefinitionsCached(Value value);
+ /// Find the definitions of the tensor value used by `opOperand` or retrieve
+ /// them from the cache (which is keyed by `opOperand->get()`).
+ const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand);
/// Reset cached data structures.
void resetCache() override;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 349841f06959c3..7ca9659ef86ee2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const {
return false;
}
-// Starting from `value`, follow the use-def chain in reverse, always selecting
-// the aliasing OpOperands. Find and return Values for which `condition`
-// evaluates to true. OpOperands of such matching Values are not traversed any
-// further, the visited aliasing opOperands will be preserved through
-// `visitedOpOperands`.
+// Starting from `opOperand`, follow the use-def chain in reverse, always
+// selecting the aliasing OpOperands. Find and return Values for which
+// `condition` evaluates to true. OpOperands of such matching Values are not
+// traversed any further, the visited aliasing opOperands will be preserved
+// through `visitedOpOperands`, which also receives `opOperand` itself.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
- Value value, llvm::function_ref<bool(Value)> condition,
+ OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
TraversalConfig config,
llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
llvm::DenseSet<Value> visited;
llvm::SetVector<Value> result, workingSet;
- workingSet.insert(value);
+ workingSet.insert(opOperand->get());
+
+ if (visitedOpOperands)
+ visitedOpOperands->insert(opOperand);
while (!workingSet.empty()) {
Value value = workingSet.pop_back_val();
@@ -563,12 +566,14 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
return result;
}
-// Find the values that define the contents of the given value.
-llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
+// Find the values that define the contents of the value used by `opOperand`.
+llvm::SetVector<Value>
+AnalysisState::findDefinitions(OpOperand *opOperand) const {
TraversalConfig config;
config.alwaysIncludeLeaves = false;
return findValueInReverseUseDefChain(
- value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
+ opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+ config);
}
AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -892,7 +897,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
config.alwaysIncludeLeaves = false;
for (AliasingOpOperand alias : opOperands) {
if (!state
- .findValueInReverseUseDefChain(alias.opOperand->get(),
+ .findValueInReverseUseDefChain(alias.opOperand,
isMemoryWriteInsideOp, config)
.empty())
return true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 98c3d8d0adc6d2..84c2da6df093bd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -143,7 +143,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
// %3 = tensor.insert_slice %2 into ...
config.followSameTypeOrCastsOnly = true;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
- source.get(), /*condition=*/
+ &source, /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
&visitedOpOperands);
@@ -155,10 +155,10 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
return llvm::count(emptyTensorOp->getUses(), *opOperand);
});
- // This could be achieved when a use of `emptyTensorOp` is being
- // consumed by `SubsetInsertionOpInterface`'s source directly.
- if (iter == visitedOpOperands.end())
- continue;
+ // `visitedOpOperands` includes `source` and the aliasing OpOperands
+ // traversed from it, so one of them must be a use of `emptyTensorOp`.
+ assert(iter != visitedOpOperands.end() &&
+ "expected a visited OpOperand to use the empty tensor");
OpOperand *useToBeReplaced = *iter;
Operation *user = useToBeReplaced->getOwner();
auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index d1e6acef324fbd..2f50b0f02876dd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
// If there is no preceding definition, the tensor contents are
// undefined.
- if (findDefinitionsCached(opResult).empty())
+ // Unused results have no uses to record; move on to the next result.
+ if (opResult.getUses().empty())
+ continue;
+ // The definitions only depend on the used value, so any use will do.
+ OpOperand *opOperand = &(*opResult.getUses().begin());
+ if (findDefinitionsCached(opOperand).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
@@ -464,20 +469,22 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
- Value start, Value other) {
+ OpOperand *start,
+ OpOperand *other) {
TraversalConfig config;
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
config.followSameTypeOrCastsOnly = true;
return !state
.findValueInReverseUseDefChain(
- start, [&](Value v) { return v == other; }, config)
+ start, [&](Value v) { return v == other->get(); }, config)
.empty();
}
-/// Return "true" if `value` is originating from a subset that is equivalent to
+/// Return "true" if `opOperand`'s value originates from a subset equivalent to
/// the subset that `subsetOp` inserts into.
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
+static bool matchesInsertDestination(const AnalysisState &state,
+ OpOperand *opOperand,
SubsetInsertionOpInterface subsetOp) {
auto matchingSubset = [&](Value val) {
if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
// There may be multiple leaves at which the reverse SSA use-def chain lookup
// terminates. All of them must be equivalent subsets.
SetVector<Value> backwardSlice =
- state.findValueInReverseUseDefChain(value, matchingSubset);
+ state.findValueInReverseUseDefChain(opOperand, matchingSubset);
return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
}
@@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
// {inplace= [true] }
if (uRead == &subsetOp.getDestinationOperand() &&
- matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+ matchesInsertDestination(state, uConflictingWrite, subsetOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
if (uRead == &subsetOp.getSourceOperand() &&
uConflictingWrite == &subsetOp.getDestinationOperand() &&
- matchesInsertDestination(state, uRead->get(), subsetOp))
+ matchesInsertDestination(state, uRead, subsetOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@@ -567,7 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
state.areEquivalentBufferizedValues(
uRead->get(), subsetOp.getSourceOperand().get()) &&
- matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
+ matchesInsertDestination(state, &subsetOp.getSourceOperand(),
subsetOp))
return true;
@@ -601,7 +608,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// the contents of the buffer.
SetVector<Value> definitionsOrLeaves =
state.findValueInReverseUseDefChain(
- uConflictingWrite->get(),
+ uConflictingWrite,
[&](Value v) { return state.bufferizesToMemoryWrite(v); });
assert(!definitionsOrLeaves.empty() &&
"expected at least one definition or leaf");
@@ -642,7 +649,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// definition is %0. Note that operations that create an alias but do not
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
const SetVector<Value> &definitions =
- state.findDefinitionsCached(uRead->get());
+ state.findDefinitionsCached(uRead);
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
LLVM_DEBUG(llvm::dbgs()
@@ -714,9 +721,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (bufferizableOp.bufferizesToElementwiseAccess(
state, {uRead, uConflictingWrite})) {
if (hasEquivalentValueInReverseUseDefChain(
- state, uRead->get(), uConflictingWrite->get()) ||
+ state, uRead, uConflictingWrite) ||
hasEquivalentValueInReverseUseDefChain(
- state, uConflictingWrite->get(), uRead->get())) {
+ state, uConflictingWrite, uRead)) {
LLVM_DEBUG(
llvm::dbgs()
<< " no conflict: op bufferizes to element-wise access\n");
@@ -965,11 +972,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
// Bufferization analyses.
//===----------------------------------------------------------------------===//
-// Find the values that define the contents of the given value.
+// Find the values that define the contents of the value used by `opOperand`.
const llvm::SetVector<Value> &
-OneShotAnalysisState::findDefinitionsCached(Value value) {
+OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
+ Value value = opOperand->get();
if (!cachedDefinitions.count(value))
- cachedDefinitions[value] = findDefinitions(value);
+ cachedDefinitions[value] = findDefinitions(opOperand);
return cachedDefinitions[value];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 6801b68a853815..6c1087730ebba8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -553,7 +553,7 @@ Value linalg::bufferizeToAllocation(
Value alloc = createAllocationForTensor(
rewriter, op->getLoc(), operand->get(), options, memorySpace);
allocs.push_back(alloc);
- if (!state.findDefinitions(operand->get()).empty()) {
+ if (!state.findDefinitions(operand).empty()) {
// Initialize buffer with a copy of the operand data. Not needed if the
// tensor is uninitialized.
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 4776883ed95c5c..b710bde87f9f33 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -59,7 +59,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
- in->get(), /*condition=*/
+ in, /*condition=*/
[&](Value val) {
return val.getDefiningOp<tensor::EmptyOp>() &&
val.getType() == in->get().getType();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 26434774730e1b..41ab9cd113b39a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -465,3 +465,15 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
}
+
+// -----
+
+// A tensor.empty used directly as an insert_slice source must not allocate.
+// CHECK-LABEL: func.func @direct_use_of_tensor_empty
+func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
+ // CHECK-NOT: memref.alloc
+ %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+ %inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1]
+ : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+ return %inserted_slice_1 : tensor<5x6x128xf32>
+}
More information about the Mlir-commits
mailing list