[Mlir-commits] [mlir] [mlir] Add support for staged dataflow analyses (PR #192998)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 20 08:17:00 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Max191

<details>
<summary>Changes</summary>

Adds support for reusing the dataflow solver after a previously loaded set of analyses has run to a fixpoint. This makes it possible to run one set of analyses to convergence and then run a dependent analysis that requires their converged results. The new solver function, `initializeAndRunPendingAnalyses`, initializes and runs any analyses that have been loaded since the solver last ran, while preserving previously computed analysis states. The existing `initializeAndRun` function is kept with the same functionality as before, so existing uses of the solver remain unaffected.
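
As a rough illustration of the bookkeeping this implies, here is a minimal, self-contained sketch. `ToyStagedSolver` and its string log are hypothetical stand-ins invented for this example and are not part of MLIR. Only the method names and the three tracked fields mirror the diff. The real solver also drains a worklist to a fixpoint, which this toy model omits.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model of the staging bookkeeping; not mlir::DataFlowSolver.
class ToyStagedSolver {
public:
  // A toy "analysis" appends its name to a shared log and reports success.
  using Analysis = std::function<bool(std::vector<std::string> &)>;

  void load(Analysis analysis) { analyses.push_back(std::move(analysis)); }

  // Restart from scratch: discard earlier results, then run every analysis.
  bool initializeAndRun(const void *root) {
    eraseAllStates();
    return runFrom(root, /*firstAnalysis=*/0);
  }

  // Resume: keep earlier results and run only the analyses loaded since the
  // last successful run. Rejects a failed session or a different root.
  bool initializeAndRunPendingAnalyses(const void *root) {
    if (hasFailedRun)
      return false;
    if (analysisRoot && analysisRoot != root)
      return false;
    return runFrom(root, initializedAnalysisCount);
  }

  void eraseAllStates() {
    log.clear();
    initializedAnalysisCount = 0;
    analysisRoot = nullptr;
    hasFailedRun = false;
  }

  const std::vector<std::string> &getLog() const { return log; }

private:
  bool runFrom(const void *root, std::size_t firstAnalysis) {
    analysisRoot = root;
    hasFailedRun = true; // Cleared only if every pending analysis succeeds.
    for (std::size_t i = firstAnalysis; i < analyses.size(); ++i)
      if (!analyses[i](log))
        return false;
    initializedAnalysisCount = analyses.size();
    hasFailedRun = false;
    return true;
  }

  std::vector<Analysis> analyses;
  std::vector<std::string> log;
  const void *analysisRoot = nullptr;
  std::size_t initializedAnalysisCount = 0;
  bool hasFailedRun = false;
};

// Usage: stage "foo" first, then add "bar" without re-running "foo".
inline void stagedSolverDemo() {
  int root = 0;
  ToyStagedSolver solver;
  solver.load([](std::vector<std::string> &log) {
    log.push_back("foo");
    return true;
  });
  assert(solver.initializeAndRun(&root));
  solver.load([](std::vector<std::string> &log) {
    log.push_back("bar");
    return true;
  });
  assert(solver.initializeAndRunPendingAnalyses(&root));
  assert(solver.getLog().size() == 2); // "foo" ran once, then "bar".
}
```

In this sketch, the second call starts at the stored analysis count, so the first stage is not initialized again. Calling `initializeAndRun` instead would clear the log and run both analyses together.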

A new analysis and a test pass are also added to illustrate how staged analyses can be useful. The example analysis, `BarAnalysis`, depends on the converged state of `FooAnalysis`. It is a forward analysis that tracks, for each program point, whether any preceding program point holds a `foo_state` that is divisible by 4. In the example test, the control flow graph looks like this:

```
  entry-block
   /       \
bb0         bb2
    \     /
      bb1
```
The `foo_state` of `bb1` depends on the `foo_state` of both `bb0` and `bb2`. If the solver processes `bb0->bb1` before `bb2->bb1`, there is an intermediate stage of the analysis in which the state of `bb1` can be divisible by 4, even though its converged state is not. If `BarAnalysis` runs on `bb1` in this intermediate stage, it gets stuck in the "divisible by 4" state and does not yield the desired result.
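
The latching behavior can be sketched in isolation. `ToyBarState` and `observe` below are hypothetical names for this illustration and are not the classes from the patch. The values 4 (transient) and 7 (converged) are the ones given in the test file's comments for the first op of the join block.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a monotone "all values so far are non-multiples
// of 4" flag. Once it becomes false it can never return to true.
struct ToyBarState {
  std::optional<bool> state;

  void join(bool value) {
    // The first join seeds the state; later joins AND the new value in.
    state = state.has_value() ? (*state && value) : value;
  }
};

// Feeds a sequence of observed foo values into the latch and returns the
// final flag. The seed `true` models the entry state.
inline bool observe(const std::vector<std::uint64_t> &fooValues) {
  ToyBarState bar;
  bar.join(true);
  for (std::uint64_t value : fooValues)
    bar.join((value & 0x3) != 0); // True when value is not divisible by 4.
  return *bar.state;
}

// Usage: a transient 4 poisons the flag even though the final value is 7.
inline void latchDemo() {
  assert(observe({4, 7}) == false); // Saw the intermediate state first.
  assert(observe({7}) == true);     // Saw only the converged state.
}
```

Because the flag only ever moves from true to false, no later observation can undo the transient one, which is why the dependent analysis needs to start from converged inputs.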

This PR ensures that `BarAnalysis` sees the converged `foo_state`, because `FooAnalysis` runs fully to a fixpoint before `BarAnalysis` is loaded, initialized, and run.

The Foo and Bar analyses are trivial examples, but the pattern is useful whenever an analysis can be made more effective by consuming the converged results of complementary analyses, such as integer range or divisibility analyses.

Assisted-by: Codex (gpt-5.4)

---
Full diff: https://github.com/llvm/llvm-project/pull/192998.diff


5 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlowFramework.h (+53-3) 
- (modified) mlir/lib/Analysis/DataFlowFramework.cpp (+38-2) 
- (added) mlir/test/Analysis/DataFlow/test-staged-analyses.mlir (+50) 
- (modified) mlir/test/lib/Analysis/TestDataFlowFramework.cpp (+175-2) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 87ec01a918d90..721aaef668c2b 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -331,10 +331,19 @@ class DataFlowSolver {
   template <typename AnalysisT, typename... Args>
   AnalysisT *load(Args &&...args);
 
-  /// Initialize the children analyses starting from the provided top-level
-  /// operation and run the analysis until fixpoint.
+  /// Initialize the children analyses from scratch starting from the provided
+  /// top-level operation and run the analysis until fixpoint, discarding any
+  /// previously computed analysis state.
   LogicalResult initializeAndRun(Operation *top);
 
+  /// Initialize any loaded analyses that have not yet been initialized for the
+  /// current solver session, then continue running the solver until fixpoint.
+  ///
+  /// This preserves all previously computed analysis states and is intended for
+  /// staged analysis pipelines where later analyses depend on the converged
+  /// results of earlier ones.
+  LogicalResult initializeAndRunPendingAnalyses(Operation *top);
+
   /// Lookup an analysis state for the given lattice anchor. Returns null if one
   /// does not exist.
   template <typename StateT, typename AnchorT>
@@ -358,6 +367,10 @@ class DataFlowSolver {
   void eraseAllStates() {
     analysisStates.clear();
     equivalentAnchorMap.clear();
+    initializedAnalysisCount = 0;
+    analysisRoot = nullptr;
+    hasFailedRun = false;
+    worklist = std::queue<WorkItem>();
   }
 
   /// Get a uniqued lattice anchor instance. If one is not present, it is
@@ -432,6 +445,13 @@ class DataFlowSolver {
   const DataFlowConfig &getConfig() const { return config; }
 
 private:
+  /// Initialize analyses in the range [firstAnalysis, childAnalyses.size())
+  /// and continue running the solver until fixpoint.
+  LogicalResult initializeAndRunImpl(Operation *top, size_t firstAnalysis);
+
+  /// Drain the worklist to a fixpoint.
+  LogicalResult runToFixpoint();
+
   /// Configuration of the dataflow solver.
   DataFlowConfig config;
 
@@ -443,6 +463,17 @@ class DataFlowSolver {
   /// quickly degenerate to quadratic due to propagation of state updates.
   std::queue<WorkItem> worklist;
 
+  /// The root operation the current solver session was initialized with.
+  Operation *analysisRoot = nullptr;
+
+  /// The number of analyses that have been initialized for the current solver
+  /// session.
+  size_t initializedAnalysisCount = 0;
+
+  /// Whether the current solver session has failed and must be reset before
+  /// attempting an incremental run.
+  bool hasFailedRun = false;
+
   /// Type-erased instances of the children analyses.
   SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
 
@@ -764,9 +795,28 @@ bool DataFlowSolver::isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const {
 
 template <typename StateT, typename AnchorT>
 void DataFlowSolver::unionLatticeAnchors(AnchorT anchor, AnchorT other) {
+  // States are stored on equivalence-class leaders, so canonicalize before
+  // checking for materialized states. For example, if `B` is already
+  // equivalent to leader `A` and `A` owns the state, a later union(B, C) must
+  // observe `A`'s state. Returning early keeps redundant unions like
+  // union(A, B) as harmless no-ops.
+  LatticeAnchor lhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(anchor));
+  LatticeAnchor rhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(other));
+  if (lhs == rhs)
+    return;
+
+  auto hasStateForAnchor = [&](LatticeAnchor latticeAnchor) {
+    auto anchorIt = analysisStates.find(latticeAnchor);
+    return anchorIt != analysisStates.end() &&
+           anchorIt->second.contains(TypeID::get<StateT>());
+  };
+  assert(!hasStateForAnchor(lhs) && !hasStateForAnchor(rhs) &&
+         "cannot union lattice anchors after analysis states have been "
+         "materialized for the state type");
+
   llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
       equivalentAnchorMap[TypeID::get<StateT>()];
-  eqClass.unionSets(LatticeAnchor(anchor), LatticeAnchor(other));
+  eqClass.unionSets(lhs, rhs);
 }
 
 inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 258bcf312afc5..3763b2f9da233 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/iterator.h"
 #include "llvm/Config/abi-breaking.h"
 #include "llvm/Support/Casting.h"
@@ -110,6 +111,28 @@ Location LatticeAnchor::getLoc() const {
 //===----------------------------------------------------------------------===//
 
 LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
+  eraseAllStates();
+  return initializeAndRunImpl(top, /*firstAnalysis=*/0);
+}
+
+LogicalResult DataFlowSolver::initializeAndRunPendingAnalyses(Operation *top) {
+  if (hasFailedRun) {
+    return top->emitError("dataflow solver is in a failed state after a "
+                          "previous run; call 'initializeAndRun()' to "
+                          "restart or 'eraseAllStates()' before reusing it");
+  }
+  if (analysisRoot && analysisRoot != top) {
+    return top->emitError("dataflow solver can only be resumed with the same "
+                          "top-level operation used for the original run");
+  }
+  return initializeAndRunImpl(top, initializedAnalysisCount);
+}
+
+LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
+                                                   size_t firstAnalysis) {
+  analysisRoot = top;
+  hasFailedRun = true;
+
   // Enable enqueue to the worklist.
   isRunning = true;
   llvm::scope_exit guard([&]() { isRunning = false; });
@@ -121,17 +144,30 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
     config.setInterprocedural(false);
 
   // Initialize equivalent lattice anchors.
-  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+  for (DataFlowAnalysis &analysis :
+       llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
+                                                 firstAnalysis))) {
     analysis.initializeEquivalentLatticeAnchor(top);
   }
 
   // Initialize the analyses.
-  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+  for (DataFlowAnalysis &analysis :
+       llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
+                                                 firstAnalysis))) {
     DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
     if (failed(analysis.initialize(top)))
       return failure();
   }
 
+  if (failed(runToFixpoint()))
+    return failure();
+
+  initializedAnalysisCount = childAnalyses.size();
+  hasFailedRun = false;
+  return success();
+}
+
+LogicalResult DataFlowSolver::runToFixpoint() {
   // Run the analysis until fixpoint.
   // Iterate until all states are in some initialized state and the worklist
   // is exhausted.
diff --git a/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir b/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir
new file mode 100644
index 0000000000000..5da7ad8bf1274
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -pass-pipeline='builtin.module(func.func(test-staged-analyses))' %s | FileCheck %s
+
+// CHECK-LABEL: func.func @linear()
+func.func @linear() {
+  // CHECK: "test.foo"() {bar_state = true, foo = 1 : ui64, foo_state = 1 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate", foo = 1 : ui64} : () -> ()
+  // CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 3 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> ()
+  // CHECK: "test.foo"() {bar_state = true, foo_state = 3 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate"} : () -> ()
+  return
+}
+
+// This demonstrates why `BarAnalysis` should be run only after `FooAnalysis`
+// converges.
+//
+// Under the current `FooAnalysis` implementation:
+//   - entry op after-state is 0 xor 7 = 7
+//   - bb0 terminator after-state is 7 xor 1 = 6
+//   - when the join block is first visited, only bb0 has contributed, so the
+//     join op transiently sees 6 xor 2 = 4
+//   - once the other predecessor arrives, revisiting the join updates the
+//     final staged `foo_state` to 7 for the first op in the join block and it
+//     stays 7 for the following op
+//
+// But if a non-staged `BarAnalysis` observed bb2 after only bb0 had reached
+// it, bb2's first tagged op would transiently see 6 xor 2 = 4 and latch
+// `bar_state = false`, poisoning later points. The staged run below must use
+// only the converged `FooState`, so `bar_state` stays true.
+//
+// CHECK-LABEL: func.func @requires_staged_bar()
+func.func @requires_staged_bar() {
+  // CHECK: "test.branch"()[^bb{{[0-9]+}}, ^bb{{[0-9]+}}] {bar_state = true, foo = 7 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> ()
+  "test.branch"() [^bb0, ^bb2] {tag = "annotate", foo = 7 : ui64} : () -> ()
+
+^bb0:
+  // CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 1 : ui64, foo_state = 6 : i64, tag = "annotate"} : () -> ()
+  "test.branch"() [^bb1] {tag = "annotate", foo = 1 : ui64} : () -> ()
+
+^bb1:
+  // CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> ()
+  // CHECK: "test.foo"() {bar_state = true, foo_state = 7 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate"} : () -> ()
+  return
+
+^bb2:
+  // CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 2 : ui64, foo_state = 5 : i64, tag = "annotate"} : () -> ()
+  "test.branch"() [^bb1] {tag = "annotate", foo = 2 : ui64} : () -> ()
+}
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 4267fb42266ce..5da3a764f6a9a 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -8,12 +8,18 @@
 
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>
 
 using namespace mlir;
 
 namespace {
+constexpr char kTagAttrName[] = "tag";
+constexpr char kFooAttrName[] = "foo";
+constexpr char kFooStateAttrName[] = "foo_state";
+constexpr char kBarStateAttrName[] = "bar_state";
+
 /// This analysis state represents an integer that is XOR'd with other states.
 class FooState : public AnalysisState {
 public:
@@ -82,6 +88,70 @@ class FooAnalysis : public DataFlowAnalysis {
   void visitOperation(Operation *op);
 };
 
+/// This analysis state stores whether all previously observed `FooState`
+/// values at tagged program points along the CFG leading to the current point
+/// have been non-multiples of 4. Once the state becomes false at some point,
+/// all later points reachable from it also remain false.
+class BarState : public AnalysisState {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarState)
+
+  using AnalysisState::AnalysisState;
+
+  bool isUninitialized() const { return !state; }
+
+  void print(raw_ostream &os) const override {
+    if (!state) {
+      os << "none";
+      return;
+    }
+    os << (*state ? "true" : "false");
+  }
+
+  ChangeResult join(const BarState &rhs) {
+    if (rhs.isUninitialized())
+      return ChangeResult::NoChange;
+    return join(rhs.getValue());
+  }
+
+  ChangeResult join(bool value) {
+    if (isUninitialized()) {
+      state = value;
+      return ChangeResult::Change;
+    }
+    bool newValue = *state && value;
+    if (newValue == *state)
+      return ChangeResult::NoChange;
+    state = newValue;
+    return ChangeResult::Change;
+  }
+
+  bool getValue() const { return *state; }
+
+private:
+  std::optional<bool> state;
+};
+
+/// This analysis is intended to be loaded after `FooAnalysis` has converged.
+/// It records whether every observed `FooState` on or before a given tagged
+/// program point has been non-divisible by 4. Because the state only ever
+/// transitions from true to false, observing a transient divisible-by-4
+/// `FooState` before `FooAnalysis` converges can permanently poison the
+/// result.
+class BarAnalysis : public DataFlowAnalysis {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarAnalysis)
+
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  LogicalResult initialize(Operation *top) override;
+  LogicalResult visit(ProgramPoint *point) override;
+
+private:
+  void visitBlock(Block *block);
+  void visitOperation(Operation *op);
+};
+
 struct TestFooAnalysisPass
     : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
@@ -90,6 +160,15 @@ struct TestFooAnalysisPass
 
   void runOnOperation() override;
 };
+
+struct TestStagedAnalysesPass
+    : public PassWrapper<TestStagedAnalysesPass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStagedAnalysesPass)
+
+  StringRef getArgument() const override { return "test-staged-analyses"; }
+
+  void runOnOperation() override;
+};
 } // namespace
 
 LogicalResult FooAnalysis::initialize(Operation *top) {
@@ -151,13 +230,76 @@ void FooAnalysis::visitOperation(Operation *op) {
   result |= state->set(*prevState);
 
   // Modify the state with the attribute, if specified.
-  if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
+  if (auto attr = op->getAttrOfType<IntegerAttr>(kFooAttrName)) {
     uint64_t value = attr.getUInt();
     result |= state->join(value);
   }
   propagateIfChanged(state, result);
 }
 
+LogicalResult BarAnalysis::initialize(Operation *top) {
+  if (top->getNumRegions() != 1)
+    return top->emitError("expected a single region top-level op");
+
+  if (top->getRegion(0).getBlocks().empty())
+    return top->emitError("expected at least one block in the region");
+
+  // Seed the entry state to true before observing any `FooState`.
+  (void)getOrCreate<BarState>(getProgramPointBefore(&top->getRegion(0).front()))
+      ->join(true);
+
+  for (Block &block : top->getRegion(0)) {
+    visitBlock(&block);
+    for (Operation &op : block) {
+      if (op.getNumRegions())
+        return op.emitError("unexpected op with regions");
+      visitOperation(&op);
+    }
+  }
+  return success();
+}
+
+LogicalResult BarAnalysis::visit(ProgramPoint *point) {
+  if (!point->isBlockStart())
+    visitOperation(point->getPrevOp());
+  else
+    visitBlock(point->getBlock());
+  return success();
+}
+
+void BarAnalysis::visitBlock(Block *block) {
+  if (block->isEntryBlock())
+    return;
+
+  ProgramPoint *point = getProgramPointBefore(block);
+  BarState *state = getOrCreate<BarState>(point);
+  ChangeResult result = ChangeResult::NoChange;
+  for (Block *pred : block->getPredecessors()) {
+    const BarState *predState = getOrCreateFor<BarState>(
+        point, getProgramPointAfter(pred->getTerminator()));
+    result |= state->join(*predState);
+  }
+  propagateIfChanged(state, result);
+}
+
+void BarAnalysis::visitOperation(Operation *op) {
+  ProgramPoint *point = getProgramPointAfter(op);
+  BarState *state = getOrCreate<BarState>(point);
+  ChangeResult result = ChangeResult::NoChange;
+
+  const BarState *prevState =
+      getOrCreateFor<BarState>(point, getProgramPointBefore(op));
+  result |= state->join(*prevState);
+
+  if (op->hasAttr(kTagAttrName)) {
+    const FooState *fooState = getOrCreateFor<FooState>(point, point);
+    if (fooState->isUninitialized())
+      return;
+    result |= state->join((fooState->getValue() & 0x3) != 0);
+  }
+  propagateIfChanged(state, result);
+}
+
 void TestFooAnalysisPass::runOnOperation() {
   func::FuncOp func = getOperation();
   DataFlowSolver solver;
@@ -169,7 +311,7 @@ void TestFooAnalysisPass::runOnOperation() {
   os << "function: @" << func.getSymName() << "\n";
 
   func.walk([&](Operation *op) {
-    auto tag = op->getAttrOfType<StringAttr>("tag");
+    auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
     if (!tag)
       return;
     const FooState *state =
@@ -179,8 +321,39 @@ void TestFooAnalysisPass::runOnOperation() {
   });
 }
 
+void TestStagedAnalysesPass::runOnOperation() {
+  func::FuncOp func = getOperation();
+  Builder builder(func.getContext());
+
+  DataFlowSolver solver;
+  solver.load<FooAnalysis>();
+  if (failed(solver.initializeAndRun(func)))
+    return signalPassFailure();
+  solver.load<BarAnalysis>();
+  if (failed(solver.initializeAndRunPendingAnalyses(func)))
+    return signalPassFailure();
+
+  func.walk([&](Operation *op) {
+    if (!op->hasAttr(kTagAttrName))
+      return;
+
+    ProgramPoint *point = solver.getProgramPointAfter(op);
+    const FooState *fooState = solver.lookupState<FooState>(point);
+    const BarState *barState = solver.lookupState<BarState>(point);
+    assert(fooState && !fooState->isUninitialized());
+    assert(barState && !barState->isUninitialized());
+
+    op->setAttr(kFooStateAttrName,
+                builder.getI64IntegerAttr(fooState->getValue()));
+    op->setAttr(kBarStateAttrName, builder.getBoolAttr(barState->getValue()));
+  });
+}
+
 namespace mlir {
 namespace test {
 void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
+void registerTestStagedAnalysesPass() {
+  PassRegistration<TestStagedAnalysesPass>();
+}
 } // namespace test
 } // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 48b8c179bd1b0..c4754b3a08551 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -99,6 +99,7 @@ void registerTestDynamicPipelinePass();
 void registerTestRemarkPass();
 void registerTestEmulateNarrowTypePass();
 void registerTestFooAnalysisPass();
+void registerTestStagedAnalysesPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
 void registerTestIRVisitorsPass();
@@ -247,6 +248,7 @@ static void registerTestPasses() {
   mlir::test::registerTestRemarkPass();
   mlir::test::registerTestEmulateNarrowTypePass();
   mlir::test::registerTestFooAnalysisPass();
+  mlir::test::registerTestStagedAnalysesPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestIRVisitorsPass();

``````````

</details>


https://github.com/llvm/llvm-project/pull/192998

