[Mlir-commits] [mlir] 1f479c1 - [mlir][bufferization] Improve findValueInReverseUseDefChain signature
Matthias Springer
llvmlistbot at llvm.org
Mon May 15 06:32:08 PDT 2023
Author: Matthias Springer
Date: 2023-05-15T15:31:56+02:00
New Revision: 1f479c1e46d111a6f001cf4ee24290f60f13257d
URL: https://github.com/llvm/llvm-project/commit/1f479c1e46d111a6f001cf4ee24290f60f13257d
DIFF: https://github.com/llvm/llvm-project/commit/1f479c1e46d111a6f001cf4ee24290f60f13257d.diff
LOG: [mlir][bufferization] Improve findValueInReverseUseDefChain signature
Instead of passing traversal options as a long list of arguments, store them in a TraversalConfig object and pass that object.
Differential Revision: https://reviews.llvm.org/D143927
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 2dbd113547e91..45d705c444a7e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -392,6 +392,23 @@ struct BufferizationOptions {
/// Return `true` if the given value is a BlockArgument of a func::FuncOp.
bool isFunctionArgument(Value value);
+/// Traversal parameters for `findValueInReverseUseDefChain`.
+struct TraversalConfig {
+ /// Specifies if leaves should be returned even if they do not match the
+ /// specified filter. A leaf is a Value with no further OpOperands to follow,
+ /// or a Value at which the traversal stopped because of one of the
+ /// `follow*` options below.
+ bool alwaysIncludeLeaves = true;
+
+ /// If set, only in-place OpOperands are followed; the traversal is meant to
+ /// stop at out-of-place/undecided OpOperands.
+ bool followInPlaceOnly = false;
+
+ /// If set, only OpOperands whose aliasing relation is
+ /// `BufferRelation::Equivalent` are followed; the traversal is meant to stop
+ /// at non-equivalent OpOperands.
+ bool followEquivalentOnly = false;
+
+ /// If set, the traversal also continues through unknown/non-bufferizable
+ /// ops and ops not included in the OpFilter of BufferizationOptions. If
+ /// unset, such ops end the traversal.
+ bool followUnknownOps = false;
+};
+
/// AnalysisState provides a variety of helper functions for dealing with
/// tensor values.
class AnalysisState {
@@ -437,9 +454,8 @@ class AnalysisState {
/// `condition` evaluates to true. OpOperands of such matching Values are not
/// traversed any further.
///
- /// When reaching the end of a chain (BlockArgument or Value without aliasing
- /// OpOperands), also return the last Value of that chain if
- /// `alwaysIncludeLeaves` is set.
+ /// When reaching the end of a chain, also return the last Value of that
+ /// chain if `config.alwaysIncludeLeaves` is set.
///
/// Example:
///
@@ -457,10 +473,11 @@ class AnalysisState {
/// starting the traversal from Value 1, the resulting SetVector is:
/// { 2, 7, 8, 5 }
///
- /// If `followEquivalentOnly` is set, only equivalent OpOperands are selected.
+ /// Additional stopping conditions for the traversal can be specified in
+ /// `config`.
SetVector<Value> findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition,
- bool followEquivalentOnly = false, bool alwaysIncludeLeaves = true) const;
+ TraversalConfig config = TraversalConfig()) const;
/// Find the values that may define the contents of the given value at
/// runtime. A block argument is always a definition. An OpResult is a
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index cb69a9e5879c0..712693ddd53a1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -367,10 +367,7 @@ BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
BufferizableOpInterface
BufferizationOptions::dynCastBufferizableOp(Value value) const {
- if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
- if (isOpAllowed(bufferizableOp.getOperation()))
- return bufferizableOp;
- return nullptr;
+ return dynCastBufferizableOp(getOwnerOfValue(value));
}
void BufferizationOptions::setFunctionBoundaryTypeConversion(
@@ -500,7 +497,7 @@ bool AnalysisState::isValueRead(Value value) const {
// further.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition,
- bool followEquivalentOnly, bool alwaysIncludeLeaves) const {
+ TraversalConfig config) const {
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
@@ -512,7 +509,7 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
}
if (llvm::isa<BlockArgument>(value)) {
- if (alwaysIncludeLeaves)
+ if (config.alwaysIncludeLeaves)
result.insert(value);
continue;
}
@@ -520,26 +517,43 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
OpResult opResult = llvm::cast<OpResult>(value);
BufferizableOpInterface bufferizableOp =
options.dynCastBufferizableOp(opResult.getDefiningOp());
- AliasingOpOperandList aliases = getAliasingOpOperands(opResult);
+ if (!config.followUnknownOps && !bufferizableOp) {
+ // Stop iterating if `followUnknownOps` is unset and the op is either
+ // not bufferizable or excluded in the OpFilter.
+ if (config.alwaysIncludeLeaves)
+ result.insert(value);
+ continue;
+ }
- // Stop iterating in either one of these cases:
- // * The current op is not bufferizable or excluded in the filter.
- // * There are no OpOperands to follow.
- if (!bufferizableOp || aliases.getNumAliases() == 0) {
- if (alwaysIncludeLeaves)
+ AliasingOpOperandList aliases = getAliasingOpOperands(opResult);
+ if (aliases.getNumAliases() == 0) {
+ // The traversal ends naturally if there are no more OpOperands that
+ // could be followed.
+ if (config.alwaysIncludeLeaves)
result.insert(value);
continue;
}
for (AliasingOpOperand a : aliases) {
- if (followEquivalentOnly && a.relation != BufferRelation::Equivalent) {
+ if (config.followEquivalentOnly &&
+ a.relation != BufferRelation::Equivalent) {
// Stop iterating if `followEquivalentOnly` is set but the alias is not
// equivalent.
- if (alwaysIncludeLeaves)
+ if (config.alwaysIncludeLeaves)
result.insert(value);
- } else {
- workingSet.insert(a.opOperand->get());
+ continue;
}
+
+ if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) {
+ // Stop iterating if `followInPlaceOnly` is set but the alias is
+ // out-of-place.
+ if (config.alwaysIncludeLeaves)
+ result.insert(value);
+ continue;
+ }
+
+ // The alias passed every stopping condition: keep traversing through it.
+ workingSet.insert(a.opOperand->get());
}
}
@@ -548,9 +562,10 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
// Find the values that define the contents of the given value.
llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
+ TraversalConfig config;
+ config.alwaysIncludeLeaves = false;
return findValueInReverseUseDefChain(
- value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
- /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
+ value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
}
AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -927,12 +942,12 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
return false;
return state.bufferizesToMemoryWrite(v);
};
+ TraversalConfig config;
+ config.alwaysIncludeLeaves = false;
for (AliasingOpOperand alias : opOperands) {
if (!state
.findValueInReverseUseDefChain(alias.opOperand->get(),
- isMemoryWriteInsideOp,
- /*followEquivalentOnly=*/false,
- /*alwaysIncludeLeaves=*/false)
+ isMemoryWriteInsideOp, config)
.empty())
return true;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 58475d225ce8b..76d424867af61 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -132,10 +132,13 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
+ TraversalConfig config;
+ config.followEquivalentOnly = true;
+ config.alwaysIncludeLeaves = false;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
operand.get(), /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
- /*followEquivalentOnly=*/true, /*alwaysIncludeLeaves=*/false);
+ config);
for (Value v : emptyTensors) {
Operation *emptyTensorOp = v.getDefiningOp();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 6da512699cc7b..a9f05b21282de 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -775,11 +775,8 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
// Find the values that define the contents of the given value.
const llvm::SetVector<Value> &
OneShotAnalysisState::findDefinitionsCached(Value value) {
- if (!cachedDefinitions.count(value)) {
- cachedDefinitions[value] = findValueInReverseUseDefChain(
- value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
- /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
- }
+ if (!cachedDefinitions.count(value))
+ cachedDefinitions[value] = findDefinitions(value);
return cachedDefinitions[value];
}
More information about the Mlir-commits mailing list