[Mlir-commits] [mlir] dc7ad19 - [mlir][bufferize][NFC] Optimize read-only tensor detection
Matthias Springer
llvmlistbot at llvm.org
Thu Feb 9 00:07:25 PST 2023
Author: Matthias Springer
Date: 2023-02-09T09:07:14+01:00
New Revision: dc7ad194c77cbe5b49514bebcfcfc4afd9eb8439
URL: https://github.com/llvm/llvm-project/commit/dc7ad194c77cbe5b49514bebcfcfc4afd9eb8439
DIFF: https://github.com/llvm/llvm-project/commit/dc7ad194c77cbe5b49514bebcfcfc4afd9eb8439.diff
LOG: [mlir][bufferize][NFC] Optimize read-only tensor detection
Check alias sets instead of traversing the IR.
Differential Revision: https://reviews.llvm.org/D143500
Added:
Modified:
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 6420feb51fa1d..8a7d660f488f9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -724,62 +724,42 @@ static void annotateNonWritableTensor(Value value) {
}
}
-/// Check the reverse SSA use-def chain (following aliasing OpOperands) for
-/// non-writable tensor values. Stop searching when an out-of-place bufferized
-/// OpOperand was found (or when the OpOperand was not bufferized yet).
-/// `currentOpOperand` is assumed to be in-place, even if that decision was not
-/// materialized in `aliasInfo` yet.
-static bool
-hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
- const OneShotAnalysisState &state) {
- SmallVector<Value> worklist;
- worklist.push_back(value);
- while (!worklist.empty()) {
- Value nextVal = worklist.pop_back_val();
- if (!state.isWritable(nextVal)) {
- if (state.getOptions().printConflicts)
- annotateNonWritableTensor(nextVal);
- return true;
- }
-
- // If `nextVal` is not a BlockArgument: End of use-def chain reached.
- auto opResult = nextVal.dyn_cast<OpResult>();
- if (!opResult)
- continue;
-
- // Follow reverse SSA use-def chain.
- AliasingOpOperandList aliasingOpOperands =
- state.getAliasingOpOperands(opResult);
- for (OpOperand *opOperand : aliasingOpOperands)
- if (state.isInPlace(*opOperand) || currentOpOperand == opOperand)
- worklist.push_back(opOperand->get());
- }
- return false;
-}
-
/// Return true if bufferizing `operand` inplace would create a write to a
/// non-writable buffer.
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
OneShotAnalysisState &state,
bool checkConsistencyOnly = false) {
- // Collect writes of all aliases of OpOperand and OpResult.
- DenseSet<OpOperand *> usesWrite;
- getAliasingInplaceWrites(usesWrite, operand.get(), state);
- for (OpResult result : state.getAliasingOpResults(operand)) {
- getAliasingInplaceWrites(usesWrite, result, state);
+ bool foundWrite =
+ !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);
+
+ if (!foundWrite) {
+ // Collect writes of all aliases of OpOperand and OpResult.
+ DenseSet<OpOperand *> usesWrite;
+ getAliasingInplaceWrites(usesWrite, operand.get(), state);
+ for (OpResult result : state.getAliasingOpResults(operand))
+ getAliasingInplaceWrites(usesWrite, result, state);
+ foundWrite = !usesWrite.empty();
}
- if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
- usesWrite.insert(&operand);
- // Assuming that `operand` bufferizes in-place: For each write (to each
- // alias), check if there is a non-writable tensor in the reverse SSA use-def
- // chain.
- for (OpOperand *uWrite : usesWrite) {
- if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) {
- LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
- return true;
+ if (!foundWrite)
+ return false;
+
+ // Look for a read-only tensor among all aliases.
+ bool foundReadOnly = false;
+ auto checkReadOnly = [&](Value v) {
+ if (!state.isWritable(v)) {
+ foundReadOnly = true;
+ if (state.getOptions().printConflicts)
+ annotateNonWritableTensor(v);
}
+ };
+ state.applyOnAliases(operand.get(), checkReadOnly);
+ for (OpResult result : state.getAliasingOpResults(operand))
+ state.applyOnAliases(result, checkReadOnly);
+ if (foundReadOnly) {
+ LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
+ return true;
}
return false;
More information about the Mlir-commits
mailing list