[Mlir-commits] [mlir] [mlir][bufferization] Add "bottom-up from terminators" analysis heuristic (PR #83964)
Han-Chung Wang
llvmlistbot at llvm.org
Tue Mar 19 14:29:02 PDT 2024
================
@@ -1094,41 +1095,98 @@ static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
equivalenceAnalysis(ops, state);
}
-LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
- const DominanceInfo &domInfo) {
- // Collect ops so we can build our own reverse traversal.
- SmallVector<Operation *> ops;
- op->walk([&](Operation *op) {
- // No tensors => no buffers.
- if (!hasTensorSemantics(op))
+/// "Bottom-up from terminators" heuristic.
+static SmallVector<Operation *>
+bottomUpFromTerminatorsHeuristic(Operation *op,
+ const OneShotAnalysisState &state) {
+ SetVector<Operation *> traversedOps;
+
+ // Find region terminators.
+ op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
+ if (!traversedOps.insert(term))
return;
- ops.push_back(op);
+ // Follow the reverse SSA use-def chain from each yielded value as long as
+ // we stay within the same region.
+ SmallVector<OpResult> worklist;
+ for (Value v : term->getOperands()) {
+ if (!isa<TensorType>(v.getType()))
+ continue;
+ auto opResult = dyn_cast<OpResult>(v);
+ if (!opResult)
+ continue;
+ worklist.push_back(opResult);
+ }
+ while (!worklist.empty()) {
+ OpResult opResult = worklist.pop_back_val();
+ Operation *defOp = opResult.getDefiningOp();
+ if (!traversedOps.insert(defOp))
+ continue;
+ if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
+ continue;
+ AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
+ for (auto alias : aliases) {
+ Value v = alias.opOperand->get();
+ if (!isa<TensorType>(v.getType()))
+ continue;
+ auto opResult = dyn_cast<OpResult>(v);
+ if (!opResult)
+ continue;
+ worklist.push_back(opResult);
+ }
+ }
});
- if (getOptions().analysisFuzzerSeed) {
- // This is a fuzzer. For testing purposes only. Randomize the order in which
- // operations are analyzed. The bufferization quality is likely worse, but
- // we want to make sure that no assertions are triggered anywhere.
- std::mt19937 g(getOptions().analysisFuzzerSeed);
- llvm::shuffle(ops.begin(), ops.end(), g);
- }
+ // Analyze traversed ops, then all remaining ops.
+ SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
+ op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
+ if (!traversedOps.contains(op) && hasTensorSemantics(op))
+ result.push_back(op);
+ });
+ return result;
+}
+LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
+ const DominanceInfo &domInfo) {
OneShotBufferizationOptions::AnalysisHeuristic heuristic =
getOptions().analysisHeuristic;
- if (heuristic == OneShotBufferizationOptions::AnalysisHeuristic::BottomUp) {
- // Default: Walk ops in reverse for better interference analysis.
- for (Operation *op : reverse(ops))
- if (failed(analyzeSingleOp(op, domInfo)))
- return failure();
- } else if (heuristic ==
- OneShotBufferizationOptions::AnalysisHeuristic::TopDown) {
- for (Operation *op : ops)
- if (failed(analyzeSingleOp(op, domInfo)))
- return failure();
+
+ SmallVector<Operation *> orderedOps;
+ if (heuristic ==
+ OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators) {
----------------
hanhanW wrote:
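I think we can use a switch statement here instead of an if/else chain. Without a `default` label, the compiler's `-Wswitch` warning will flag it if someone adds a new heuristic and forgets to update this place.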
https://github.com/llvm/llvm-project/pull/83964
More information about the Mlir-commits
mailing list