[Mlir-commits] [mlir] 36ec848 - [mlir][linalg][bufferize][NFC] Change findValueInReverseUseDefChain signature
Matthias Springer
llvmlistbot at llvm.org
Thu Oct 21 01:34:15 PDT 2021
Author: Matthias Springer
Date: 2021-10-21T17:34:05+09:00
New Revision: 36ec848dc7186b9713bb69ade134f9b2b7d65070
URL: https://github.com/llvm/llvm-project/commit/36ec848dc7186b9713bb69ade134f9b2b7d65070
DIFF: https://github.com/llvm/llvm-project/commit/36ec848dc7186b9713bb69ade134f9b2b7d65070.diff
LOG: [mlir][linalg][bufferize][NFC] Change findValueInReverseUseDefChain signature
This commit is in preparation for `scf.if` support.
* The `condition` callback of `findValueInReverseUseDefChain` now takes a `Value` instead of an `OpOperand &`.
* Return a `SetVector<Value>` instead of a single `Value`. For now, this SetVector always contains exactly one Value.
Differential Revision: https://reviews.llvm.org/D111928
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index be18cd135a47..0655d97b4ccb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -836,28 +836,37 @@ void BufferizationAliasInfo::bufferizeOutOfPlace(OpResult result) {
/// Starting from `value`, follow the use-def chain in reverse, always selecting
/// the corresponding aliasing OpOperand. Try to find and return a Value for
-/// which `condition` evaluates to true for the aliasing OpOperand. Return an
-/// empty Value if no such Value was found. If `returnLast`, return the last
-/// Value (at the end of the chain), even if it does not satisfy the condition.
-static Value
+/// which `condition` evaluates to true.
+///
+/// When reaching the end of the chain (BlockArgument or Value without aliasing
+/// OpOperands), return the last Value of the chain.
+///
+/// Note: The returned SetVector contains exactly one element.
+static llvm::SetVector<Value>
findValueInReverseUseDefChain(Value value,
- std::function<bool(OpOperand &)> condition,
- bool returnLast = false) {
- while (value.isa<OpResult>()) {
- auto opResult = value.cast<OpResult>();
+ std::function<bool(Value)> condition) {
+ llvm::SetVector<Value> result, workingSet;
+ workingSet.insert(value);
+
+ while (!workingSet.empty()) {
+ Value value = workingSet.pop_back_val();
+ if (condition(value) || value.isa<BlockArgument>()) {
+ result.insert(value);
+ continue;
+ }
+
+ OpResult opResult = value.cast<OpResult>();
SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
- assert(opOperands.size() <= 1 && "more than 1 OpOperand not supported yet");
- if (opOperands.empty())
- // No aliasing OpOperand. This could be an unsupported op or an op without
- // a tensor arg such as InitTensorOp. This is the end of the chain.
- return returnLast ? value : Value();
- OpOperand *opOperand = opOperands.front();
- if (condition(*opOperand))
- return value;
- value = opOperand->get();
+ if (opOperands.empty()) {
+ result.insert(value);
+ continue;
+ }
+
+ assert(opOperands.size() == 1 && "multiple OpOperands not supported yet");
+ workingSet.insert(opOperands.front()->get());
}
- // Value is a BlockArgument. Reached the end of the chain.
- return returnLast ? value : Value();
+
+ return result;
}
/// Find the Value (result) of the last preceding write of a given Value.
@@ -866,20 +875,41 @@ findValueInReverseUseDefChain(Value value,
/// Furthermore, BlockArguments are also assumed to be writes. There is no
/// analysis across block boundaries.
static Value findLastPrecedingWrite(Value value) {
- return findValueInReverseUseDefChain(value, bufferizesToMemoryWrite, true);
+ SetVector<Value> result =
+ findValueInReverseUseDefChain(value, [](Value value) {
+ Operation *op = value.getDefiningOp();
+ if (!op)
+ return true;
+ if (!hasKnownBufferizationAliasingBehavior(op))
+ return true;
+
+ SmallVector<OpOperand *> opOperands =
+ getAliasingOpOperand(value.cast<OpResult>());
+ assert(opOperands.size() <= 1 &&
+ "op with multiple aliasing OpOperands not expected");
+
+ if (opOperands.empty())
+ return true;
+
+ return bufferizesToMemoryWrite(*opOperands.front());
+ });
+ assert(result.size() == 1 && "expected exactly one result");
+ return result.front();
}
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
Value value, InsertSliceOp insertOp) const {
- return static_cast<bool>(
- findValueInReverseUseDefChain(value, [&](OpOperand &opOperand) {
- if (auto extractOp = dyn_cast<ExtractSliceOp>(opOperand.getOwner()))
- if (areEquivalentExtractSliceOps(extractOp, insertOp))
- return true;
- return false;
- }));
+ auto condition = [&](Value val) {
+ if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+ if (areEquivalentExtractSliceOps(extractOp, insertOp))
+ return true;
+ return false;
+ };
+
+ return llvm::all_of(findValueInReverseUseDefChain(value, condition),
+ condition);
}
/// Given sets of uses and writes, return true if there is a RaW conflict under
More information about the Mlir-commits mailing list