[Mlir-commits] [mlir] [mlir] Add support for staged dataflow analyses (PR #192998)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 24 08:39:18 PDT 2026


https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/192998

>From c1c9166c23ec540941bcfdba5662d233bc473a49 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 14 Apr 2026 21:27:03 +0000
Subject: [PATCH 1/4] Add staged dataflow solver support

---
 .../include/mlir/Analysis/DataFlowFramework.h |  56 +++++-
 mlir/lib/Analysis/DataFlowFramework.cpp       |  40 +++-
 .../DataFlow/test-staged-analyses.mlir        |  50 +++++
 .../lib/Analysis/TestDataFlowFramework.cpp    | 177 +++++++++++++++++-
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 5 files changed, 318 insertions(+), 7 deletions(-)
 create mode 100644 mlir/test/Analysis/DataFlow/test-staged-analyses.mlir

diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 87ec01a918d90..721aaef668c2b 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -331,10 +331,19 @@ class DataFlowSolver {
   template <typename AnalysisT, typename... Args>
   AnalysisT *load(Args &&...args);
 
-  /// Initialize the children analyses starting from the provided top-level
-  /// operation and run the analysis until fixpoint.
+  /// Initialize the children analyses from scratch starting from the provided
+  /// top-level operation and run the analysis until fixpoint, discarding any
+  /// previously computed analysis state.
   LogicalResult initializeAndRun(Operation *top);
 
+  /// Initialize any loaded analyses that have not yet been initialized for the
+  /// current solver session, then continue running the solver until fixpoint.
+  ///
+  /// This preserves all previously computed analysis states and is intended for
+  /// staged analysis pipelines where later analyses depend on the converged
+  /// results of earlier ones.
+  LogicalResult initializeAndRunPendingAnalyses(Operation *top);
+
   /// Lookup an analysis state for the given lattice anchor. Returns null if one
   /// does not exist.
   template <typename StateT, typename AnchorT>
@@ -358,6 +367,10 @@ class DataFlowSolver {
   void eraseAllStates() {
     analysisStates.clear();
     equivalentAnchorMap.clear();
+    initializedAnalysisCount = 0;
+    analysisRoot = nullptr;
+    hasFailedRun = false;
+    worklist = std::queue<WorkItem>();
   }
 
   /// Get a uniqued lattice anchor instance. If one is not present, it is
@@ -432,6 +445,13 @@ class DataFlowSolver {
   const DataFlowConfig &getConfig() const { return config; }
 
 private:
+  /// Initialize analyses in the range [firstAnalysis, childAnalyses.size())
+  /// and continue running the solver until fixpoint.
+  LogicalResult initializeAndRunImpl(Operation *top, size_t firstAnalysis);
+
+  /// Drain the worklist to a fixpoint.
+  LogicalResult runToFixpoint();
+
   /// Configuration of the dataflow solver.
   DataFlowConfig config;
 
@@ -443,6 +463,17 @@ class DataFlowSolver {
   /// quickly degenerate to quadratic due to propagation of state updates.
   std::queue<WorkItem> worklist;
 
+  /// The root operation the current solver session was initialized with.
+  Operation *analysisRoot = nullptr;
+
+  /// The number of analyses that have been initialized for the current solver
+  /// session.
+  size_t initializedAnalysisCount = 0;
+
+  /// Whether the current solver session has failed and must be reset before
+  /// attempting an incremental run.
+  bool hasFailedRun = false;
+
   /// Type-erased instances of the children analyses.
   SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
 
@@ -764,9 +795,28 @@ bool DataFlowSolver::isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const {
 
 template <typename StateT, typename AnchorT>
 void DataFlowSolver::unionLatticeAnchors(AnchorT anchor, AnchorT other) {
+  // States are stored on equivalence-class leaders, so canonicalize before
+  // checking for materialized states. For example, if `B` is already
+  // equivalent to leader `A` and `A` owns the state, a later union(B, C) must
+  // observe `A`'s state. Returning early keeps redundant unions like
+  // union(A, B) as harmless no-ops.
+  LatticeAnchor lhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(anchor));
+  LatticeAnchor rhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(other));
+  if (lhs == rhs)
+    return;
+
+  auto hasStateForAnchor = [&](LatticeAnchor latticeAnchor) {
+    auto anchorIt = analysisStates.find(latticeAnchor);
+    return anchorIt != analysisStates.end() &&
+           anchorIt->second.contains(TypeID::get<StateT>());
+  };
+  assert(!hasStateForAnchor(lhs) && !hasStateForAnchor(rhs) &&
+         "cannot union lattice anchors after analysis states have been "
+         "materialized for the state type");
+
   llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
       equivalentAnchorMap[TypeID::get<StateT>()];
-  eqClass.unionSets(LatticeAnchor(anchor), LatticeAnchor(other));
+  eqClass.unionSets(lhs, rhs);
 }
 
 inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 258bcf312afc5..3763b2f9da233 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/iterator.h"
 #include "llvm/Config/abi-breaking.h"
 #include "llvm/Support/Casting.h"
@@ -110,6 +111,28 @@ Location LatticeAnchor::getLoc() const {
 //===----------------------------------------------------------------------===//
 
 LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
+  eraseAllStates();
+  return initializeAndRunImpl(top, /*firstAnalysis=*/0);
+}
+
+LogicalResult DataFlowSolver::initializeAndRunPendingAnalyses(Operation *top) {
+  if (hasFailedRun) {
+    return top->emitError("dataflow solver is in a failed state after a "
+                          "previous run; call 'initializeAndRun()' to "
+                          "restart or 'eraseAllStates()' before reusing it");
+  }
+  if (analysisRoot && analysisRoot != top) {
+    return top->emitError("dataflow solver can only be resumed with the same "
+                          "top-level operation used for the original run");
+  }
+  return initializeAndRunImpl(top, initializedAnalysisCount);
+}
+
+LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
+                                                   size_t firstAnalysis) {
+  analysisRoot = top;
+  hasFailedRun = true;
+
   // Enable enqueue to the worklist.
   isRunning = true;
   llvm::scope_exit guard([&]() { isRunning = false; });
@@ -121,17 +144,30 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
     config.setInterprocedural(false);
 
   // Initialize equivalent lattice anchors.
-  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+  for (DataFlowAnalysis &analysis :
+       llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
+                                                 firstAnalysis))) {
     analysis.initializeEquivalentLatticeAnchor(top);
   }
 
   // Initialize the analyses.
-  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+  for (DataFlowAnalysis &analysis :
+       llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
+                                                 firstAnalysis))) {
     DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
     if (failed(analysis.initialize(top)))
       return failure();
   }
 
+  if (failed(runToFixpoint()))
+    return failure();
+
+  initializedAnalysisCount = childAnalyses.size();
+  hasFailedRun = false;
+  return success();
+}
+
+LogicalResult DataFlowSolver::runToFixpoint() {
   // Run the analysis until fixpoint.
   // Iterate until all states are in some initialized state and the worklist
   // is exhausted.
diff --git a/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir b/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir
new file mode 100644
index 0000000000000..5da7ad8bf1274
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -pass-pipeline='builtin.module(func.func(test-staged-analyses))' %s | FileCheck %s
+
+// CHECK-LABEL: func.func @linear()
+func.func @linear() {
+  // CHECK: "test.foo"() {bar_state = true, foo = 1 : ui64, foo_state = 1 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate", foo = 1 : ui64} : () -> ()
+  // CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 3 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> ()
+  // CHECK: "test.foo"() {bar_state = true, foo_state = 3 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate"} : () -> ()
+  return
+}
+
+// This demonstrates why `BarAnalysis` should be run only after `FooAnalysis`
+// converges.
+//
+// Under the current `FooAnalysis` implementation:
+//   - entry op after-state is 0 xor 7 = 7
+//   - bb0 terminator after-state is 7 xor 1 = 6 (bb2's is 7 xor 2 = 5)
+//   - when the join block bb1 is first visited, only bb0 has contributed, so
+//     the join op transiently sees 6 xor 2 = 4
+//   - once the other predecessor arrives, revisiting the join updates the
+//     final staged `foo_state` to 7 for the first op in the join block and it
+//     stays 7 for the following op
+//
+// But if a non-staged `BarAnalysis` observed bb1 after only bb0 had reached
+// it, bb1's first tagged op would transiently see 6 xor 2 = 4 and latch
+// `bar_state = false`, poisoning later points. The staged run below must use
+// only the converged `FooState`, so `bar_state` stays true.
+//
+// CHECK-LABEL: func.func @requires_staged_bar()
+func.func @requires_staged_bar() {
+  // CHECK: "test.branch"()[^bb{{[0-9]+}}, ^bb{{[0-9]+}}] {bar_state = true, foo = 7 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> ()
+  "test.branch"() [^bb0, ^bb2] {tag = "annotate", foo = 7 : ui64} : () -> ()
+
+^bb0:
+  // CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 1 : ui64, foo_state = 6 : i64, tag = "annotate"} : () -> ()
+  "test.branch"() [^bb1] {tag = "annotate", foo = 1 : ui64} : () -> ()
+
+^bb1:
+  // CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> ()
+  // CHECK: "test.foo"() {bar_state = true, foo_state = 7 : i64, tag = "annotate"} : () -> ()
+  "test.foo"() {tag = "annotate"} : () -> ()
+  return
+
+^bb2:
+  // CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 2 : ui64, foo_state = 5 : i64, tag = "annotate"} : () -> ()
+  "test.branch"() [^bb1] {tag = "annotate", foo = 2 : ui64} : () -> ()
+}
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 4267fb42266ce..5da3a764f6a9a 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -8,12 +8,18 @@
 
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>
 
 using namespace mlir;
 
 namespace {
+constexpr char kTagAttrName[] = "tag";
+constexpr char kFooAttrName[] = "foo";
+constexpr char kFooStateAttrName[] = "foo_state";
+constexpr char kBarStateAttrName[] = "bar_state";
+
 /// This analysis state represents an integer that is XOR'd with other states.
 class FooState : public AnalysisState {
 public:
@@ -82,6 +88,70 @@ class FooAnalysis : public DataFlowAnalysis {
   void visitOperation(Operation *op);
 };
 
+/// This analysis state stores whether all previously observed `FooState`
+/// values at tagged program points along the CFG leading to the current point
+/// have been non-multiples of 4. Once the state becomes false at some point,
+/// all later points reachable from it also remain false.
+class BarState : public AnalysisState {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarState)
+
+  using AnalysisState::AnalysisState;
+
+  bool isUninitialized() const { return !state; }
+
+  void print(raw_ostream &os) const override {
+    if (!state) {
+      os << "none";
+      return;
+    }
+    os << (*state ? "true" : "false");
+  }
+
+  ChangeResult join(const BarState &rhs) {
+    if (rhs.isUninitialized())
+      return ChangeResult::NoChange;
+    return join(rhs.getValue());
+  }
+
+  ChangeResult join(bool value) {
+    if (isUninitialized()) {
+      state = value;
+      return ChangeResult::Change;
+    }
+    bool newValue = *state && value;
+    if (newValue == *state)
+      return ChangeResult::NoChange;
+    state = newValue;
+    return ChangeResult::Change;
+  }
+
+  bool getValue() const { return *state; }
+
+private:
+  std::optional<bool> state;
+};
+
+/// This analysis is intended to be loaded after `FooAnalysis` has converged.
+/// It records whether every observed `FooState` on or before a given tagged
+/// program point has been non-divisible by 4. Because the state only ever
+/// transitions from true to false, observing a transient divisible-by-4
+/// `FooState` before `FooAnalysis` converges can permanently poison the
+/// result.
+class BarAnalysis : public DataFlowAnalysis {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarAnalysis)
+
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  LogicalResult initialize(Operation *top) override;
+  LogicalResult visit(ProgramPoint *point) override;
+
+private:
+  void visitBlock(Block *block);
+  void visitOperation(Operation *op);
+};
+
 struct TestFooAnalysisPass
     : public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
@@ -90,6 +160,15 @@ struct TestFooAnalysisPass
 
   void runOnOperation() override;
 };
+
+struct TestStagedAnalysesPass
+    : public PassWrapper<TestStagedAnalysesPass, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStagedAnalysesPass)
+
+  StringRef getArgument() const override { return "test-staged-analyses"; }
+
+  void runOnOperation() override;
+};
 } // namespace
 
 LogicalResult FooAnalysis::initialize(Operation *top) {
@@ -151,13 +230,76 @@ void FooAnalysis::visitOperation(Operation *op) {
   result |= state->set(*prevState);
 
   // Modify the state with the attribute, if specified.
-  if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
+  if (auto attr = op->getAttrOfType<IntegerAttr>(kFooAttrName)) {
     uint64_t value = attr.getUInt();
     result |= state->join(value);
   }
   propagateIfChanged(state, result);
 }
 
+LogicalResult BarAnalysis::initialize(Operation *top) {
+  if (top->getNumRegions() != 1)
+    return top->emitError("expected a single region top-level op");
+
+  if (top->getRegion(0).getBlocks().empty())
+    return top->emitError("expected at least one block in the region");
+
+  // Seed the entry state to true before observing any `FooState`.
+  (void)getOrCreate<BarState>(getProgramPointBefore(&top->getRegion(0).front()))
+      ->join(true);
+
+  for (Block &block : top->getRegion(0)) {
+    visitBlock(&block);
+    for (Operation &op : block) {
+      if (op.getNumRegions())
+        return op.emitError("unexpected op with regions");
+      visitOperation(&op);
+    }
+  }
+  return success();
+}
+
+LogicalResult BarAnalysis::visit(ProgramPoint *point) {
+  if (!point->isBlockStart())
+    visitOperation(point->getPrevOp());
+  else
+    visitBlock(point->getBlock());
+  return success();
+}
+
+void BarAnalysis::visitBlock(Block *block) {
+  if (block->isEntryBlock())
+    return;
+
+  ProgramPoint *point = getProgramPointBefore(block);
+  BarState *state = getOrCreate<BarState>(point);
+  ChangeResult result = ChangeResult::NoChange;
+  for (Block *pred : block->getPredecessors()) {
+    const BarState *predState = getOrCreateFor<BarState>(
+        point, getProgramPointAfter(pred->getTerminator()));
+    result |= state->join(*predState);
+  }
+  propagateIfChanged(state, result);
+}
+
+void BarAnalysis::visitOperation(Operation *op) {
+  ProgramPoint *point = getProgramPointAfter(op);
+  BarState *state = getOrCreate<BarState>(point);
+  ChangeResult result = ChangeResult::NoChange;
+  // Join in the state flowing from the program point before `op`.
+  const BarState *prevState =
+      getOrCreateFor<BarState>(point, getProgramPointBefore(op));
+  result |= state->join(*prevState);
+  if (op->hasAttr(kTagAttrName)) {
+    // Fold in the FooState check once available; an uninitialized FooState
+    // must not suppress propagating the predecessor join above.
+    const FooState *fooState = getOrCreateFor<FooState>(point, point);
+    if (!fooState->isUninitialized())
+      result |= state->join((fooState->getValue() & 0x3) != 0);
+  }
+  propagateIfChanged(state, result);
+}
+
 void TestFooAnalysisPass::runOnOperation() {
   func::FuncOp func = getOperation();
   DataFlowSolver solver;
@@ -169,7 +311,7 @@ void TestFooAnalysisPass::runOnOperation() {
   os << "function: @" << func.getSymName() << "\n";
 
   func.walk([&](Operation *op) {
-    auto tag = op->getAttrOfType<StringAttr>("tag");
+    auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
     if (!tag)
       return;
     const FooState *state =
@@ -179,8 +321,39 @@ void TestFooAnalysisPass::runOnOperation() {
   });
 }
 
+void TestStagedAnalysesPass::runOnOperation() {
+  func::FuncOp func = getOperation();
+  Builder builder(func.getContext());
+
+  DataFlowSolver solver;
+  solver.load<FooAnalysis>();
+  if (failed(solver.initializeAndRun(func)))
+    return signalPassFailure();
+  solver.load<BarAnalysis>();
+  if (failed(solver.initializeAndRunPendingAnalyses(func)))
+    return signalPassFailure();
+
+  func.walk([&](Operation *op) {
+    if (!op->hasAttr(kTagAttrName))
+      return;
+
+    ProgramPoint *point = solver.getProgramPointAfter(op);
+    const FooState *fooState = solver.lookupState<FooState>(point);
+    const BarState *barState = solver.lookupState<BarState>(point);
+    assert(fooState && !fooState->isUninitialized());
+    assert(barState && !barState->isUninitialized());
+
+    op->setAttr(kFooStateAttrName,
+                builder.getI64IntegerAttr(fooState->getValue()));
+    op->setAttr(kBarStateAttrName, builder.getBoolAttr(barState->getValue()));
+  });
+}
+
 namespace mlir {
 namespace test {
 void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
+void registerTestStagedAnalysesPass() {
+  PassRegistration<TestStagedAnalysesPass>();
+}
 } // namespace test
 } // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 48b8c179bd1b0..c4754b3a08551 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -99,6 +99,7 @@ void registerTestDynamicPipelinePass();
 void registerTestRemarkPass();
 void registerTestEmulateNarrowTypePass();
 void registerTestFooAnalysisPass();
+void registerTestStagedAnalysesPass();
 void registerTestComposeSubView();
 void registerTestMultiBuffering();
 void registerTestIRVisitorsPass();
@@ -247,6 +248,7 @@ static void registerTestPasses() {
   mlir::test::registerTestRemarkPass();
   mlir::test::registerTestEmulateNarrowTypePass();
   mlir::test::registerTestFooAnalysisPass();
+  mlir::test::registerTestStagedAnalysesPass();
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestIRVisitorsPass();

>From b94eb2841340c3451b83d20bdb4af57147f277e9 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Mon, 20 Apr 2026 15:40:42 +0000
Subject: [PATCH 2/4] Format staged dataflow solver changes

---
 mlir/lib/Analysis/DataFlowFramework.cpp | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 3763b2f9da233..ae9340799dfee 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -11,8 +11,8 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
-#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/iterator.h"
 #include "llvm/Config/abi-breaking.h"
 #include "llvm/Support/Casting.h"
@@ -144,16 +144,14 @@ LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
     config.setInterprocedural(false);
 
   // Initialize equivalent lattice anchors.
-  for (DataFlowAnalysis &analysis :
-       llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
-                                                 firstAnalysis))) {
+  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(
+           llvm::drop_begin(childAnalyses, firstAnalysis))) {
     analysis.initializeEquivalentLatticeAnchor(top);
   }
 
   // Initialize the analyses.
-  for (DataFlowAnalysis &analysis :
-       llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
-                                                 firstAnalysis))) {
+  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(
+           llvm::drop_begin(childAnalyses, firstAnalysis))) {
     DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
     if (failed(analysis.initialize(top)))
       return failure();

>From 20ff2371521e6bcebaf328286906fbed8008ac98 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 22 Apr 2026 18:18:53 +0000
Subject: [PATCH 3/4] Address PR review: asserts, drop top param, preserve
 state on re-init

- initializeAndRunPendingAnalyses: use asserts instead of diagnostics for
  programming errors, remove top parameter and reuse stored analysisRoot.
- initializeAndRun: reset initializedAnalysisCount and analysisRoot instead
  of eraseAllStates, preserving the ability to call initializeAndRun on
  different entrypoints without losing prior state.
---
 mlir/include/mlir/Analysis/DataFlowFramework.h  |  2 +-
 mlir/lib/Analysis/DataFlowFramework.cpp         | 17 +++++++----------
 .../test/lib/Analysis/TestDataFlowFramework.cpp |  2 +-
 3 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 721aaef668c2b..c8918203ea4df 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -342,7 +342,7 @@ class DataFlowSolver {
   /// This preserves all previously computed analysis states and is intended for
   /// staged analysis pipelines where later analyses depend on the converged
   /// results of earlier ones.
-  LogicalResult initializeAndRunPendingAnalyses(Operation *top);
+  LogicalResult initializeAndRunPendingAnalyses();
 
   /// Lookup an analysis state for the given lattice anchor. Returns null if one
   /// does not exist.
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index ae9340799dfee..bdab7069a186f 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -111,21 +111,18 @@ Location LatticeAnchor::getLoc() const {
 //===----------------------------------------------------------------------===//
 
 LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
-  eraseAllStates();
+  initializedAnalysisCount = 0;
+  analysisRoot = top;
   return initializeAndRunImpl(top, /*firstAnalysis=*/0);
 }
 
-LogicalResult DataFlowSolver::initializeAndRunPendingAnalyses(Operation *top) {
-  if (hasFailedRun) {
-    return top->emitError("dataflow solver is in a failed state after a "
+LogicalResult DataFlowSolver::initializeAndRunPendingAnalyses() {
+  assert(!hasFailedRun && "dataflow solver is in a failed state after a "
                           "previous run; call 'initializeAndRun()' to "
                           "restart or 'eraseAllStates()' before reusing it");
-  }
-  if (analysisRoot && analysisRoot != top) {
-    return top->emitError("dataflow solver can only be resumed with the same "
-                          "top-level operation used for the original run");
-  }
-  return initializeAndRunImpl(top, initializedAnalysisCount);
+  assert(analysisRoot && "dataflow solver has not been run yet; call "
+                         "'initializeAndRun()' first");
+  return initializeAndRunImpl(analysisRoot, initializedAnalysisCount);
 }
 
 LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 5da3a764f6a9a..57d2d8e08bf11 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -330,7 +330,7 @@ void TestStagedAnalysesPass::runOnOperation() {
   if (failed(solver.initializeAndRun(func)))
     return signalPassFailure();
   solver.load<BarAnalysis>();
-  if (failed(solver.initializeAndRunPendingAnalyses(func)))
+  if (failed(solver.initializeAndRunPendingAnalyses()))
     return signalPassFailure();
 
   func.walk([&](Operation *op) {

>From 6db41d18ba498f5d97a0f297438fc2b172230066 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 24 Apr 2026 14:58:16 +0000
Subject: [PATCH 4/4] Add analysis filter to initializeAndRun and TypeID to
 DataFlowAnalysis

Replace the staged initializeAndRunPendingAnalyses API with an optional
analysisFilter parameter on initializeAndRun. The filter is a predicate
on DataFlowAnalysis& that controls which analyses are initialized; the
fixpoint loop processes all enqueued work items regardless.

Add a TypeID member to DataFlowAnalysis, set during load(), exposed via
getTypeID(). This enables classof/isa/IsaPred for analysis classes.
---
 .../include/mlir/Analysis/DataFlowFramework.h | 77 ++++++-------------
 mlir/lib/Analysis/DataFlowFramework.cpp       | 48 ++++--------
 .../lib/Analysis/TestDataFlowFramework.cpp    | 10 ++-
 3 files changed, 46 insertions(+), 89 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index c8918203ea4df..23e61f4890232 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -22,6 +22,7 @@
 #include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/TypeName.h"
 #include <queue>
 #include <tuple>
@@ -331,18 +332,17 @@ class DataFlowSolver {
   template <typename AnalysisT, typename... Args>
   AnalysisT *load(Args &&...args);
 
-  /// Initialize the children analyses from scratch starting from the provided
-  /// top-level operation and run the analysis until fixpoint, discarding any
-  /// previously computed analysis state.
-  LogicalResult initializeAndRun(Operation *top);
-
-  /// Initialize any loaded analyses that have not yet been initialized for the
-  /// current solver session, then continue running the solver until fixpoint.
+  /// Initialize analyses starting from the provided top-level operation and
+  /// run the analysis until fixpoint.
   ///
-  /// This preserves all previously computed analysis states and is intended for
-  /// staged analysis pipelines where later analyses depend on the converged
-  /// results of earlier ones.
-  LogicalResult initializeAndRunPendingAnalyses();
+  /// An optional \p analysisFilter predicate restricts which analyses are
+  /// initialized. When no filter is given, every loaded analysis is
+  /// (re-)initialized. The fixpoint loop always processes all enqueued work
+  /// items regardless of the filter.
+  LogicalResult
+  initializeAndRun(Operation *top,
+                   llvm::function_ref<bool(DataFlowAnalysis &)> analysisFilter =
+                       nullptr);
 
   /// Lookup an analysis state for the given lattice anchor. Returns null if one
   /// does not exist.
@@ -367,10 +367,6 @@ class DataFlowSolver {
   void eraseAllStates() {
     analysisStates.clear();
     equivalentAnchorMap.clear();
-    initializedAnalysisCount = 0;
-    analysisRoot = nullptr;
-    hasFailedRun = false;
-    worklist = std::queue<WorkItem>();
   }
 
   /// Get a uniqued lattice anchor instance. If one is not present, it is
@@ -445,13 +441,6 @@ class DataFlowSolver {
   const DataFlowConfig &getConfig() const { return config; }
 
 private:
-  /// Initialize analyses in the range [firstAnalysis, childAnalyses.size())
-  /// and continue running the solver until fixpoint.
-  LogicalResult initializeAndRunImpl(Operation *top, size_t firstAnalysis);
-
-  /// Drain the worklist to a fixpoint.
-  LogicalResult runToFixpoint();
-
   /// Configuration of the dataflow solver.
   DataFlowConfig config;
 
@@ -463,17 +452,6 @@ class DataFlowSolver {
   /// quickly degenerate to quadratic due to propagation of state updates.
   std::queue<WorkItem> worklist;
 
-  /// The root operation the current solver session was initialized with.
-  Operation *analysisRoot = nullptr;
-
-  /// The number of analyses that have been initialized for the current solver
-  /// session.
-  size_t initializedAnalysisCount = 0;
-
-  /// Whether the current solver session has failed and must be reset before
-  /// attempting an incremental run.
-  bool hasFailedRun = false;
-
   /// Type-erased instances of the children analyses.
   SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
 
@@ -661,6 +639,12 @@ class DataFlowAnalysis {
   /// necessarily identical under the corrensponding lattice type.
   virtual void initializeEquivalentLatticeAnchor(Operation *top) {}
 
+  /// Return the TypeID of the concrete analysis class. Valid only after
+  /// `DataFlowSolver::load<AnalysisT>` has returned; must not be called from
+  /// the analysis constructor body because the TypeID is set by `load` after
+  /// construction.
+  TypeID getTypeID() const { return analysisTypeID; }
+
 protected:
   /// Create a dependency between the given analysis state and lattice anchor
   /// on this analysis.
@@ -736,6 +720,11 @@ class DataFlowAnalysis {
   /// The parent data-flow solver.
   DataFlowSolver &solver;
 
+  /// The TypeID of the concrete analysis class. Set by
+  /// `DataFlowSolver::load` after construction; not available during the
+  /// analysis constructor.
+  TypeID analysisTypeID;
+
   /// Allow the data-flow solver to access the internals of this class.
   friend class DataFlowSolver;
 };
@@ -743,6 +732,7 @@ class DataFlowAnalysis {
 template <typename AnalysisT, typename... Args>
 AnalysisT *DataFlowSolver::load(Args &&...args) {
   childAnalyses.emplace_back(new AnalysisT(*this, std::forward<Args>(args)...));
+  childAnalyses.back()->analysisTypeID = TypeID::get<AnalysisT>();
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   childAnalyses.back()->debugName = llvm::getTypeName<AnalysisT>();
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -795,28 +785,9 @@ bool DataFlowSolver::isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const {
 
 template <typename StateT, typename AnchorT>
 void DataFlowSolver::unionLatticeAnchors(AnchorT anchor, AnchorT other) {
-  // States are stored on equivalence-class leaders, so canonicalize before
-  // checking for materialized states. For example, if `B` is already
-  // equivalent to leader `A` and `A` owns the state, a later union(B, C) must
-  // observe `A`'s state. Returning early keeps redundant unions like
-  // union(A, B) as harmless no-ops.
-  LatticeAnchor lhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(anchor));
-  LatticeAnchor rhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(other));
-  if (lhs == rhs)
-    return;
-
-  auto hasStateForAnchor = [&](LatticeAnchor latticeAnchor) {
-    auto anchorIt = analysisStates.find(latticeAnchor);
-    return anchorIt != analysisStates.end() &&
-           anchorIt->second.contains(TypeID::get<StateT>());
-  };
-  assert(!hasStateForAnchor(lhs) && !hasStateForAnchor(rhs) &&
-         "cannot union lattice anchors after analysis states have been "
-         "materialized for the state type");
-
   llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
       equivalentAnchorMap[TypeID::get<StateT>()];
-  eqClass.unionSets(lhs, rhs);
+  eqClass.unionSets(LatticeAnchor(anchor), LatticeAnchor(other));
 }
 
 inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index bdab7069a186f..dfa26af05cc7b 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -11,7 +11,6 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Value.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/iterator.h"
 #include "llvm/Config/abi-breaking.h"
@@ -110,26 +109,8 @@ Location LatticeAnchor::getLoc() const {
 // DataFlowSolver
 //===----------------------------------------------------------------------===//
 
-LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
-  initializedAnalysisCount = 0;
-  analysisRoot = top;
-  return initializeAndRunImpl(top, /*firstAnalysis=*/0);
-}
-
-LogicalResult DataFlowSolver::initializeAndRunPendingAnalyses() {
-  assert(!hasFailedRun && "dataflow solver is in a failed state after a "
-                          "previous run; call 'initializeAndRun()' to "
-                          "restart or 'eraseAllStates()' before reusing it");
-  assert(analysisRoot && "dataflow solver has not been run yet; call "
-                         "'initializeAndRun()' first");
-  return initializeAndRunImpl(analysisRoot, initializedAnalysisCount);
-}
-
-LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
-                                                   size_t firstAnalysis) {
-  analysisRoot = top;
-  hasFailedRun = true;
-
+LogicalResult DataFlowSolver::initializeAndRun(
+    Operation *top, llvm::function_ref<bool(DataFlowAnalysis &)> analysisFilter) {
   // Enable enqueue to the worklist.
   isRunning = true;
   llvm::scope_exit guard([&]() { isRunning = false; });
@@ -138,31 +119,28 @@ LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
   llvm::scope_exit restoreInterprocedural(
       [&]() { config.setInterprocedural(isInterprocedural); });
   if (isInterprocedural && !top->hasTrait<OpTrait::SymbolTable>())
     config.setInterprocedural(false);
+
+  auto shouldInitialize = [&](DataFlowAnalysis &analysis) {
+    return !analysisFilter || analysisFilter(analysis);
+  };
 
   // Initialize equivalent lattice anchors.
-  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(
-           llvm::drop_begin(childAnalyses, firstAnalysis))) {
+  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+    if (!shouldInitialize(analysis))
+      continue;
     analysis.initializeEquivalentLatticeAnchor(top);
   }
 
   // Initialize the analyses.
-  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(
-           llvm::drop_begin(childAnalyses, firstAnalysis))) {
+  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+    if (!shouldInitialize(analysis))
+      continue;
     DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
     if (failed(analysis.initialize(top)))
       return failure();
   }
 
-  if (failed(runToFixpoint()))
-    return failure();
-
-  initializedAnalysisCount = childAnalyses.size();
-  hasFailedRun = false;
-  return success();
-}
-
-LogicalResult DataFlowSolver::runToFixpoint() {
   // Run the analysis until fixpoint.
   // Iterate until all states are in some initialized state and the worklist
   // is exhausted.
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 57d2d8e08bf11..9af7e205aaee9 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -80,6 +80,10 @@ class FooAnalysis : public DataFlowAnalysis {
 
   using DataFlowAnalysis::DataFlowAnalysis;
 
+  static bool classof(const DataFlowAnalysis *a) {
+    return a->getTypeID() == TypeID::get<FooAnalysis>();
+  }
+
   LogicalResult initialize(Operation *top) override;
   LogicalResult visit(ProgramPoint *point) override;
 
@@ -144,6 +148,10 @@ class BarAnalysis : public DataFlowAnalysis {
 
   using DataFlowAnalysis::DataFlowAnalysis;
 
+  static bool classof(const DataFlowAnalysis *a) {
+    return a->getTypeID() == TypeID::get<BarAnalysis>();
+  }
+
   LogicalResult initialize(Operation *top) override;
   LogicalResult visit(ProgramPoint *point) override;
 
@@ -330,7 +338,7 @@ void TestStagedAnalysesPass::runOnOperation() {
   if (failed(solver.initializeAndRun(func)))
     return signalPassFailure();
   solver.load<BarAnalysis>();
-  if (failed(solver.initializeAndRunPendingAnalyses()))
+  if (failed(solver.initializeAndRun(func, llvm::IsaPred<BarAnalysis>)))
     return signalPassFailure();
 
   func.walk([&](Operation *op) {



More information about the Mlir-commits mailing list