[Mlir-commits] [mlir] [mlir][bufferization] Add "bottom-up from terminators" analysis heuristic (PR #83964)

Han-Chung Wang llvmlistbot at llvm.org
Tue Mar 19 14:29:02 PDT 2024


================
@@ -1094,41 +1095,98 @@ static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
   equivalenceAnalysis(ops, state);
 }
 
-LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
-                                              const DominanceInfo &domInfo) {
-  // Collect ops so we can build our own reverse traversal.
-  SmallVector<Operation *> ops;
-  op->walk([&](Operation *op) {
-    // No tensors => no buffers.
-    if (!hasTensorSemantics(op))
+/// "Bottom-up from terminators" heuristic.
+static SmallVector<Operation *>
+bottomUpFromTerminatorsHeuristic(Operation *op,
+                                 const OneShotAnalysisState &state) {
+  SetVector<Operation *> traversedOps;
+
+  // Find region terminators.
+  op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
+    if (!traversedOps.insert(term))
       return;
-    ops.push_back(op);
+    // Follow the reverse SSA use-def chain from each yielded value as long as
+    // we stay within the same region.
+    SmallVector<OpResult> worklist;
+    for (Value v : term->getOperands()) {
+      if (!isa<TensorType>(v.getType()))
+        continue;
+      auto opResult = dyn_cast<OpResult>(v);
+      if (!opResult)
+        continue;
+      worklist.push_back(opResult);
+    }
+    while (!worklist.empty()) {
+      OpResult opResult = worklist.pop_back_val();
+      Operation *defOp = opResult.getDefiningOp();
+      if (!traversedOps.insert(defOp))
+        continue;
+      if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
+        continue;
+      AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
+      for (auto alias : aliases) {
+        Value v = alias.opOperand->get();
+        if (!isa<TensorType>(v.getType()))
+          continue;
+        auto opResult = dyn_cast<OpResult>(v);
+        if (!opResult)
+          continue;
+        worklist.push_back(opResult);
+      }
+    }
   });
 
-  if (getOptions().analysisFuzzerSeed) {
-    // This is a fuzzer. For testing purposes only. Randomize the order in which
-    // operations are analyzed. The bufferization quality is likely worse, but
-    // we want to make sure that no assertions are triggered anywhere.
-    std::mt19937 g(getOptions().analysisFuzzerSeed);
-    llvm::shuffle(ops.begin(), ops.end(), g);
-  }
+  // Analyze traversed ops, then all remaining ops.
+  SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
+  op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *op) {
+    if (!traversedOps.contains(op) && hasTensorSemantics(op))
+      result.push_back(op);
+  });
+  return result;
+}
 
+LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
+                                              const DominanceInfo &domInfo) {
   OneShotBufferizationOptions::AnalysisHeuristic heuristic =
       getOptions().analysisHeuristic;
-  if (heuristic == OneShotBufferizationOptions::AnalysisHeuristic::BottomUp) {
-    // Default: Walk ops in reverse for better interference analysis.
-    for (Operation *op : reverse(ops))
-      if (failed(analyzeSingleOp(op, domInfo)))
-        return failure();
-  } else if (heuristic ==
-             OneShotBufferizationOptions::AnalysisHeuristic::TopDown) {
-    for (Operation *op : ops)
-      if (failed(analyzeSingleOp(op, domInfo)))
-        return failure();
+
+  SmallVector<Operation *> orderedOps;
+  if (heuristic ==
+      OneShotBufferizationOptions::AnalysisHeuristic::BottomUpFromTerminators) {
----------------
hanhanW wrote:

I think we can use a switch-case here; it would catch the problem if people forget to update this place.

https://github.com/llvm/llvm-project/pull/83964


More information about the Mlir-commits mailing list