[Mlir-commits] [mlir] [mlir] Add support for staged dataflow analyses (PR #192998)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 24 08:39:18 PDT 2026
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/192998
>From c1c9166c23ec540941bcfdba5662d233bc473a49 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Tue, 14 Apr 2026 21:27:03 +0000
Subject: [PATCH 1/4] Add staged dataflow solver support
---
.../include/mlir/Analysis/DataFlowFramework.h | 56 +++++-
mlir/lib/Analysis/DataFlowFramework.cpp | 40 +++-
.../DataFlow/test-staged-analyses.mlir | 50 +++++
.../lib/Analysis/TestDataFlowFramework.cpp | 177 +++++++++++++++++-
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
5 files changed, 318 insertions(+), 7 deletions(-)
create mode 100644 mlir/test/Analysis/DataFlow/test-staged-analyses.mlir
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 87ec01a918d90..721aaef668c2b 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -331,10 +331,19 @@ class DataFlowSolver {
template <typename AnalysisT, typename... Args>
AnalysisT *load(Args &&...args);
- /// Initialize the children analyses starting from the provided top-level
- /// operation and run the analysis until fixpoint.
+ /// Initialize the children analyses from scratch starting from the provided
+ /// top-level operation and run the analysis until fixpoint, discarding any
+ /// previously computed analysis state.
LogicalResult initializeAndRun(Operation *top);
+ /// Initialize any loaded analyses that have not yet been initialized for the
+ /// current solver session, then continue running the solver until fixpoint.
+ ///
+ /// This preserves all previously computed analysis states and is intended for
+ /// staged analysis pipelines where later analyses depend on the converged
+ /// results of earlier ones.
+ LogicalResult initializeAndRunPendingAnalyses(Operation *top);
+
/// Lookup an analysis state for the given lattice anchor. Returns null if one
/// does not exist.
template <typename StateT, typename AnchorT>
@@ -358,6 +367,10 @@ class DataFlowSolver {
void eraseAllStates() {
analysisStates.clear();
equivalentAnchorMap.clear();
+ initializedAnalysisCount = 0;
+ analysisRoot = nullptr;
+ hasFailedRun = false;
+ worklist = std::queue<WorkItem>();
}
/// Get a uniqued lattice anchor instance. If one is not present, it is
@@ -432,6 +445,13 @@ class DataFlowSolver {
const DataFlowConfig &getConfig() const { return config; }
private:
+ /// Initialize analyses in the range [firstAnalysis, childAnalyses.size())
+ /// and continue running the solver until fixpoint.
+ LogicalResult initializeAndRunImpl(Operation *top, size_t firstAnalysis);
+
+ /// Drain the worklist to a fixpoint.
+ LogicalResult runToFixpoint();
+
/// Configuration of the dataflow solver.
DataFlowConfig config;
@@ -443,6 +463,17 @@ class DataFlowSolver {
/// quickly degenerate to quadratic due to propagation of state updates.
std::queue<WorkItem> worklist;
+ /// The root operation the current solver session was initialized with.
+ Operation *analysisRoot = nullptr;
+
+ /// The number of analyses that have been initialized for the current solver
+ /// session.
+ size_t initializedAnalysisCount = 0;
+
+ /// Whether the current solver session has failed and must be reset before
+ /// attempting an incremental run.
+ bool hasFailedRun = false;
+
/// Type-erased instances of the children analyses.
SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
@@ -764,9 +795,28 @@ bool DataFlowSolver::isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const {
template <typename StateT, typename AnchorT>
void DataFlowSolver::unionLatticeAnchors(AnchorT anchor, AnchorT other) {
+ // States are stored on equivalence-class leaders, so canonicalize before
+ // checking for materialized states. For example, if `B` is already
+ // equivalent to leader `A` and `A` owns the state, a later union(B, C) must
+ // observe `A`'s state. Returning early keeps redundant unions like
+ // union(A, B) as harmless no-ops.
+ LatticeAnchor lhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(anchor));
+ LatticeAnchor rhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(other));
+ if (lhs == rhs)
+ return;
+
+ auto hasStateForAnchor = [&](LatticeAnchor latticeAnchor) {
+ auto anchorIt = analysisStates.find(latticeAnchor);
+ return anchorIt != analysisStates.end() &&
+ anchorIt->second.contains(TypeID::get<StateT>());
+ };
+ assert(!hasStateForAnchor(lhs) && !hasStateForAnchor(rhs) &&
+ "cannot union lattice anchors after analysis states have been "
+ "materialized for the state type");
+
llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
equivalentAnchorMap[TypeID::get<StateT>()];
- eqClass.unionSets(LatticeAnchor(anchor), LatticeAnchor(other));
+ eqClass.unionSets(lhs, rhs);
}
inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 258bcf312afc5..3763b2f9da233 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator.h"
#include "llvm/Config/abi-breaking.h"
#include "llvm/Support/Casting.h"
@@ -110,6 +111,28 @@ Location LatticeAnchor::getLoc() const {
//===----------------------------------------------------------------------===//
LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
+ eraseAllStates();
+ return initializeAndRunImpl(top, /*firstAnalysis=*/0);
+}
+
+LogicalResult DataFlowSolver::initializeAndRunPendingAnalyses(Operation *top) {
+ if (hasFailedRun) {
+ return top->emitError("dataflow solver is in a failed state after a "
+ "previous run; call 'initializeAndRun()' to "
+ "restart or 'eraseAllStates()' before reusing it");
+ }
+ if (analysisRoot && analysisRoot != top) {
+ return top->emitError("dataflow solver can only be resumed with the same "
+ "top-level operation used for the original run");
+ }
+ return initializeAndRunImpl(top, initializedAnalysisCount);
+}
+
+LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
+ size_t firstAnalysis) {
+ analysisRoot = top;
+ hasFailedRun = true;
+
// Enable enqueue to the worklist.
isRunning = true;
llvm::scope_exit guard([&]() { isRunning = false; });
@@ -121,17 +144,30 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
config.setInterprocedural(false);
// Initialize equivalent lattice anchors.
- for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+ for (DataFlowAnalysis &analysis :
+ llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
+ firstAnalysis))) {
analysis.initializeEquivalentLatticeAnchor(top);
}
// Initialize the analyses.
- for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+ for (DataFlowAnalysis &analysis :
+ llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
+ firstAnalysis))) {
DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
if (failed(analysis.initialize(top)))
return failure();
}
+ if (failed(runToFixpoint()))
+ return failure();
+
+ initializedAnalysisCount = childAnalyses.size();
+ hasFailedRun = false;
+ return success();
+}
+
+LogicalResult DataFlowSolver::runToFixpoint() {
// Run the analysis until fixpoint.
// Iterate until all states are in some initialized state and the worklist
// is exhausted.
diff --git a/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir b/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir
new file mode 100644
index 0000000000000..5da7ad8bf1274
--- /dev/null
+++ b/mlir/test/Analysis/DataFlow/test-staged-analyses.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -pass-pipeline='builtin.module(func.func(test-staged-analyses))' %s | FileCheck %s
+
+// CHECK-LABEL: func.func @linear()
+func.func @linear() {
+ // CHECK: "test.foo"() {bar_state = true, foo = 1 : ui64, foo_state = 1 : i64, tag = "annotate"} : () -> ()
+ "test.foo"() {tag = "annotate", foo = 1 : ui64} : () -> ()
+ // CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 3 : i64, tag = "annotate"} : () -> ()
+ "test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> ()
+ // CHECK: "test.foo"() {bar_state = true, foo_state = 3 : i64, tag = "annotate"} : () -> ()
+ "test.foo"() {tag = "annotate"} : () -> ()
+ return
+}
+
+// This demonstrates why `BarAnalysis` should be run only after `FooAnalysis`
+// converges.
+//
+// Under the current `FooAnalysis` implementation:
+// - entry op after-state is 0 xor 7 = 7
+// - bb0 terminator after-state is 7 xor 1 = 6
+// - when the join block is first visited, only bb0 has contributed, so the
+// join op transiently sees 6 xor 2 = 4
+// - once the other predecessor arrives, revisiting the join updates the
+// final staged `foo_state` to 7 for the first op in the join block and it
+// stays 7 for the following op
+//
+// But if a non-staged `BarAnalysis` observed the join block bb1 after only bb0
+// had reached it, bb1's first tagged op would transiently see 6 xor 2 = 4 and
+// latch `bar_state = false`, poisoning later points. The staged run below must
+// use only the converged `FooState`, so `bar_state` stays true.
+//
+// CHECK-LABEL: func.func @requires_staged_bar()
+func.func @requires_staged_bar() {
+ // CHECK: "test.branch"()[^bb{{[0-9]+}}, ^bb{{[0-9]+}}] {bar_state = true, foo = 7 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> ()
+ "test.branch"() [^bb0, ^bb2] {tag = "annotate", foo = 7 : ui64} : () -> ()
+
+^bb0:
+ // CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 1 : ui64, foo_state = 6 : i64, tag = "annotate"} : () -> ()
+ "test.branch"() [^bb1] {tag = "annotate", foo = 1 : ui64} : () -> ()
+
+^bb1:
+ // CHECK: "test.foo"() {bar_state = true, foo = 2 : ui64, foo_state = 7 : i64, tag = "annotate"} : () -> ()
+ "test.foo"() {tag = "annotate", foo = 2 : ui64} : () -> ()
+ // CHECK: "test.foo"() {bar_state = true, foo_state = 7 : i64, tag = "annotate"} : () -> ()
+ "test.foo"() {tag = "annotate"} : () -> ()
+ return
+
+^bb2:
+ // CHECK: "test.branch"()[^bb{{[0-9]+}}] {bar_state = true, foo = 2 : ui64, foo_state = 5 : i64, tag = "annotate"} : () -> ()
+ "test.branch"() [^bb1] {tag = "annotate", foo = 2 : ui64} : () -> ()
+}
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 4267fb42266ce..5da3a764f6a9a 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -8,12 +8,18 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include <optional>
using namespace mlir;
namespace {
+constexpr char kTagAttrName[] = "tag";
+constexpr char kFooAttrName[] = "foo";
+constexpr char kFooStateAttrName[] = "foo_state";
+constexpr char kBarStateAttrName[] = "bar_state";
+
/// This analysis state represents an integer that is XOR'd with other states.
class FooState : public AnalysisState {
public:
@@ -82,6 +88,70 @@ class FooAnalysis : public DataFlowAnalysis {
void visitOperation(Operation *op);
};
+/// This analysis state stores whether all previously observed `FooState`
+/// values at tagged program points along the CFG leading to the current point
+/// have been non-multiples of 4. Once the state becomes false at some point,
+/// all later points reachable from it also remain false.
+class BarState : public AnalysisState {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarState)
+
+ using AnalysisState::AnalysisState;
+
+ bool isUninitialized() const { return !state; }
+
+ void print(raw_ostream &os) const override {
+ if (!state) {
+ os << "none";
+ return;
+ }
+ os << (*state ? "true" : "false");
+ }
+
+ ChangeResult join(const BarState &rhs) {
+ if (rhs.isUninitialized())
+ return ChangeResult::NoChange;
+ return join(rhs.getValue());
+ }
+
+ ChangeResult join(bool value) {
+ if (isUninitialized()) {
+ state = value;
+ return ChangeResult::Change;
+ }
+ bool newValue = *state && value;
+ if (newValue == *state)
+ return ChangeResult::NoChange;
+ state = newValue;
+ return ChangeResult::Change;
+ }
+
+ bool getValue() const { return *state; }
+
+private:
+ std::optional<bool> state;
+};
+
+/// This analysis is intended to be loaded after `FooAnalysis` has converged.
+/// It records whether every observed `FooState` on or before a given tagged
+/// program point has been non-divisible by 4. Because the state only ever
+/// transitions from true to false, observing a transient divisible-by-4
+/// `FooState` before `FooAnalysis` converges can permanently poison the
+/// result.
+class BarAnalysis : public DataFlowAnalysis {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BarAnalysis)
+
+ using DataFlowAnalysis::DataFlowAnalysis;
+
+ LogicalResult initialize(Operation *top) override;
+ LogicalResult visit(ProgramPoint *point) override;
+
+private:
+ void visitBlock(Block *block);
+ void visitOperation(Operation *op);
+};
+
struct TestFooAnalysisPass
: public PassWrapper<TestFooAnalysisPass, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass)
@@ -90,6 +160,15 @@ struct TestFooAnalysisPass
void runOnOperation() override;
};
+
+struct TestStagedAnalysesPass
+ : public PassWrapper<TestStagedAnalysesPass, OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestStagedAnalysesPass)
+
+ StringRef getArgument() const override { return "test-staged-analyses"; }
+
+ void runOnOperation() override;
+};
} // namespace
LogicalResult FooAnalysis::initialize(Operation *top) {
@@ -151,13 +230,76 @@ void FooAnalysis::visitOperation(Operation *op) {
result |= state->set(*prevState);
// Modify the state with the attribute, if specified.
- if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
+ if (auto attr = op->getAttrOfType<IntegerAttr>(kFooAttrName)) {
uint64_t value = attr.getUInt();
result |= state->join(value);
}
propagateIfChanged(state, result);
}
+LogicalResult BarAnalysis::initialize(Operation *top) {
+ if (top->getNumRegions() != 1)
+ return top->emitError("expected a single region top-level op");
+
+ if (top->getRegion(0).getBlocks().empty())
+ return top->emitError("expected at least one block in the region");
+
+ // Seed the entry state to true before observing any `FooState`.
+ (void)getOrCreate<BarState>(getProgramPointBefore(&top->getRegion(0).front()))
+ ->join(true);
+
+ for (Block &block : top->getRegion(0)) {
+ visitBlock(&block);
+ for (Operation &op : block) {
+ if (op.getNumRegions())
+ return op.emitError("unexpected op with regions");
+ visitOperation(&op);
+ }
+ }
+ return success();
+}
+
+LogicalResult BarAnalysis::visit(ProgramPoint *point) {
+ if (!point->isBlockStart())
+ visitOperation(point->getPrevOp());
+ else
+ visitBlock(point->getBlock());
+ return success();
+}
+
+void BarAnalysis::visitBlock(Block *block) {
+ if (block->isEntryBlock())
+ return;
+
+ ProgramPoint *point = getProgramPointBefore(block);
+ BarState *state = getOrCreate<BarState>(point);
+ ChangeResult result = ChangeResult::NoChange;
+ for (Block *pred : block->getPredecessors()) {
+ const BarState *predState = getOrCreateFor<BarState>(
+ point, getProgramPointAfter(pred->getTerminator()));
+ result |= state->join(*predState);
+ }
+ propagateIfChanged(state, result);
+}
+
+void BarAnalysis::visitOperation(Operation *op) {
+ ProgramPoint *point = getProgramPointAfter(op);
+ BarState *state = getOrCreate<BarState>(point);
+ ChangeResult result = ChangeResult::NoChange;
+
+ const BarState *prevState =
+ getOrCreateFor<BarState>(point, getProgramPointBefore(op));
+ result |= state->join(*prevState);
+
+ if (op->hasAttr(kTagAttrName)) {
+ const FooState *fooState = getOrCreateFor<FooState>(point, point);
+ if (fooState->isUninitialized())
+ return;
+ result |= state->join((fooState->getValue() & 0x3) != 0);
+ }
+ propagateIfChanged(state, result);
+}
+
void TestFooAnalysisPass::runOnOperation() {
func::FuncOp func = getOperation();
DataFlowSolver solver;
@@ -169,7 +311,7 @@ void TestFooAnalysisPass::runOnOperation() {
os << "function: @" << func.getSymName() << "\n";
func.walk([&](Operation *op) {
- auto tag = op->getAttrOfType<StringAttr>("tag");
+ auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
if (!tag)
return;
const FooState *state =
@@ -179,8 +321,39 @@ void TestFooAnalysisPass::runOnOperation() {
});
}
+void TestStagedAnalysesPass::runOnOperation() {
+ func::FuncOp func = getOperation();
+ Builder builder(func.getContext());
+
+ DataFlowSolver solver;
+ solver.load<FooAnalysis>();
+ if (failed(solver.initializeAndRun(func)))
+ return signalPassFailure();
+ solver.load<BarAnalysis>();
+ if (failed(solver.initializeAndRunPendingAnalyses(func)))
+ return signalPassFailure();
+
+ func.walk([&](Operation *op) {
+ if (!op->hasAttr(kTagAttrName))
+ return;
+
+ ProgramPoint *point = solver.getProgramPointAfter(op);
+ const FooState *fooState = solver.lookupState<FooState>(point);
+ const BarState *barState = solver.lookupState<BarState>(point);
+ assert(fooState && !fooState->isUninitialized());
+ assert(barState && !barState->isUninitialized());
+
+ op->setAttr(kFooStateAttrName,
+ builder.getI64IntegerAttr(fooState->getValue()));
+ op->setAttr(kBarStateAttrName, builder.getBoolAttr(barState->getValue()));
+ });
+}
+
namespace mlir {
namespace test {
void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
+void registerTestStagedAnalysesPass() {
+ PassRegistration<TestStagedAnalysesPass>();
+}
} // namespace test
} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 48b8c179bd1b0..c4754b3a08551 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -99,6 +99,7 @@ void registerTestDynamicPipelinePass();
void registerTestRemarkPass();
void registerTestEmulateNarrowTypePass();
void registerTestFooAnalysisPass();
+void registerTestStagedAnalysesPass();
void registerTestComposeSubView();
void registerTestMultiBuffering();
void registerTestIRVisitorsPass();
@@ -247,6 +248,7 @@ static void registerTestPasses() {
mlir::test::registerTestRemarkPass();
mlir::test::registerTestEmulateNarrowTypePass();
mlir::test::registerTestFooAnalysisPass();
+ mlir::test::registerTestStagedAnalysesPass();
mlir::test::registerTestComposeSubView();
mlir::test::registerTestMultiBuffering();
mlir::test::registerTestIRVisitorsPass();
>From b94eb2841340c3451b83d20bdb4af57147f277e9 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Mon, 20 Apr 2026 15:40:42 +0000
Subject: [PATCH 2/4] Format staged dataflow solver changes
---
mlir/lib/Analysis/DataFlowFramework.cpp | 12 +++++-------
1 file changed, 5 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 3763b2f9da233..ae9340799dfee 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -11,8 +11,8 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
-#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/iterator.h"
#include "llvm/Config/abi-breaking.h"
#include "llvm/Support/Casting.h"
@@ -144,16 +144,14 @@ LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
config.setInterprocedural(false);
// Initialize equivalent lattice anchors.
- for (DataFlowAnalysis &analysis :
- llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
- firstAnalysis))) {
+ for (DataFlowAnalysis &analysis : llvm::make_pointee_range(
+ llvm::drop_begin(childAnalyses, firstAnalysis))) {
analysis.initializeEquivalentLatticeAnchor(top);
}
// Initialize the analyses.
- for (DataFlowAnalysis &analysis :
- llvm::make_pointee_range(llvm::drop_begin(childAnalyses,
- firstAnalysis))) {
+ for (DataFlowAnalysis &analysis : llvm::make_pointee_range(
+ llvm::drop_begin(childAnalyses, firstAnalysis))) {
DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
if (failed(analysis.initialize(top)))
return failure();
>From 20ff2371521e6bcebaf328286906fbed8008ac98 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Wed, 22 Apr 2026 18:18:53 +0000
Subject: [PATCH 3/4] Address PR review: asserts, drop top param, preserve
state on re-init
- initializeAndRunPendingAnalyses: use asserts instead of diagnostics for
programming errors, remove top parameter and reuse stored analysisRoot.
- initializeAndRun: reset initializedAnalysisCount and analysisRoot instead
of eraseAllStates, preserving the ability to call initializeAndRun on
different entrypoints without losing prior state.
---
mlir/include/mlir/Analysis/DataFlowFramework.h | 2 +-
mlir/lib/Analysis/DataFlowFramework.cpp | 17 +++++++----------
.../test/lib/Analysis/TestDataFlowFramework.cpp | 2 +-
3 files changed, 9 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 721aaef668c2b..c8918203ea4df 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -342,7 +342,7 @@ class DataFlowSolver {
/// This preserves all previously computed analysis states and is intended for
/// staged analysis pipelines where later analyses depend on the converged
/// results of earlier ones.
- LogicalResult initializeAndRunPendingAnalyses(Operation *top);
+ LogicalResult initializeAndRunPendingAnalyses();
/// Lookup an analysis state for the given lattice anchor. Returns null if one
/// does not exist.
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index ae9340799dfee..bdab7069a186f 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -111,21 +111,18 @@ Location LatticeAnchor::getLoc() const {
//===----------------------------------------------------------------------===//
LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
- eraseAllStates();
+ initializedAnalysisCount = 0;
+ analysisRoot = top;
return initializeAndRunImpl(top, /*firstAnalysis=*/0);
}
-LogicalResult DataFlowSolver::initializeAndRunPendingAnalyses(Operation *top) {
- if (hasFailedRun) {
- return top->emitError("dataflow solver is in a failed state after a "
+LogicalResult DataFlowSolver::initializeAndRunPendingAnalyses() {
+ assert(!hasFailedRun && "dataflow solver is in a failed state after a "
"previous run; call 'initializeAndRun()' to "
"restart or 'eraseAllStates()' before reusing it");
- }
- if (analysisRoot && analysisRoot != top) {
- return top->emitError("dataflow solver can only be resumed with the same "
- "top-level operation used for the original run");
- }
- return initializeAndRunImpl(top, initializedAnalysisCount);
+ assert(analysisRoot && "dataflow solver has not been run yet; call "
+ "'initializeAndRun()' first");
+ return initializeAndRunImpl(analysisRoot, initializedAnalysisCount);
}
LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 5da3a764f6a9a..57d2d8e08bf11 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -330,7 +330,7 @@ void TestStagedAnalysesPass::runOnOperation() {
if (failed(solver.initializeAndRun(func)))
return signalPassFailure();
solver.load<BarAnalysis>();
- if (failed(solver.initializeAndRunPendingAnalyses(func)))
+ if (failed(solver.initializeAndRunPendingAnalyses()))
return signalPassFailure();
func.walk([&](Operation *op) {
>From 6db41d18ba498f5d97a0f297438fc2b172230066 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Fri, 24 Apr 2026 14:58:16 +0000
Subject: [PATCH 4/4] Add analysis filter to initializeAndRun and TypeID to
DataFlowAnalysis
Replace the staged initializeAndRunPendingAnalyses API with an optional
analysisFilter parameter on initializeAndRun. The filter is a predicate
on DataFlowAnalysis& that controls which analyses are initialized; the
fixpoint loop processes all enqueued work items regardless.
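Add a TypeID member to DataFlowAnalysis, set during load(), exposed via
getTypeID(). This enables classof/isa/IsaPred for analysis classes.
Also drop the incremental-session bookkeeping (analysisRoot,
initializedAnalysisCount, hasFailedRun) together with the extra resets in
eraseAllStates(), and revert the unionLatticeAnchors canonicalization and
assertion added earlier in this series.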
---
.../include/mlir/Analysis/DataFlowFramework.h | 77 ++++++-------------
mlir/lib/Analysis/DataFlowFramework.cpp | 48 ++++--------
.../lib/Analysis/TestDataFlowFramework.cpp | 10 ++-
3 files changed, 46 insertions(+), 89 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index c8918203ea4df..23e61f4890232 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -22,6 +22,7 @@
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Compiler.h"
+#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/Support/TypeName.h"
#include <queue>
#include <tuple>
@@ -331,18 +332,17 @@ class DataFlowSolver {
template <typename AnalysisT, typename... Args>
AnalysisT *load(Args &&...args);
- /// Initialize the children analyses from scratch starting from the provided
- /// top-level operation and run the analysis until fixpoint, discarding any
- /// previously computed analysis state.
- LogicalResult initializeAndRun(Operation *top);
-
- /// Initialize any loaded analyses that have not yet been initialized for the
- /// current solver session, then continue running the solver until fixpoint.
+ /// Initialize analyses starting from the provided top-level operation and
+ /// run the analysis until fixpoint.
///
- /// This preserves all previously computed analysis states and is intended for
- /// staged analysis pipelines where later analyses depend on the converged
- /// results of earlier ones.
- LogicalResult initializeAndRunPendingAnalyses();
+ /// An optional \p analysisFilter predicate restricts which analyses are
+ /// initialized; with no filter, every loaded analysis is (re-)initialized.
+ /// Previously computed analysis states are not erased, and the fixpoint loop
+ /// always processes all enqueued work items regardless of the filter.
+ LogicalResult
+ initializeAndRun(Operation *top,
+ llvm::function_ref<bool(DataFlowAnalysis &)> analysisFilter =
+ nullptr);
/// Lookup an analysis state for the given lattice anchor. Returns null if one
/// does not exist.
@@ -367,10 +367,6 @@ class DataFlowSolver {
void eraseAllStates() {
analysisStates.clear();
equivalentAnchorMap.clear();
- initializedAnalysisCount = 0;
- analysisRoot = nullptr;
- hasFailedRun = false;
- worklist = std::queue<WorkItem>();
}
/// Get a uniqued lattice anchor instance. If one is not present, it is
@@ -445,13 +441,6 @@ class DataFlowSolver {
const DataFlowConfig &getConfig() const { return config; }
private:
- /// Initialize analyses in the range [firstAnalysis, childAnalyses.size())
- /// and continue running the solver until fixpoint.
- LogicalResult initializeAndRunImpl(Operation *top, size_t firstAnalysis);
-
- /// Drain the worklist to a fixpoint.
- LogicalResult runToFixpoint();
-
/// Configuration of the dataflow solver.
DataFlowConfig config;
@@ -463,17 +452,6 @@ class DataFlowSolver {
/// quickly degenerate to quadratic due to propagation of state updates.
std::queue<WorkItem> worklist;
- /// The root operation the current solver session was initialized with.
- Operation *analysisRoot = nullptr;
-
- /// The number of analyses that have been initialized for the current solver
- /// session.
- size_t initializedAnalysisCount = 0;
-
- /// Whether the current solver session has failed and must be reset before
- /// attempting an incremental run.
- bool hasFailedRun = false;
-
/// Type-erased instances of the children analyses.
SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
@@ -661,6 +639,12 @@ class DataFlowAnalysis {
/// necessarily identical under the corrensponding lattice type.
virtual void initializeEquivalentLatticeAnchor(Operation *top) {}
+ /// Return the TypeID of the concrete analysis class. Valid only after
+ /// `DataFlowSolver::load<AnalysisT>` has returned; must not be called from
+ /// the analysis constructor body because the TypeID is set by `load` after
+ /// construction.
+ TypeID getTypeID() const { return analysisTypeID; }
+
protected:
/// Create a dependency between the given analysis state and lattice anchor
/// on this analysis.
@@ -736,6 +720,11 @@ class DataFlowAnalysis {
/// The parent data-flow solver.
DataFlowSolver &solver;
+ /// The TypeID of the concrete analysis class. Set by
+ /// `DataFlowSolver::load` after construction; not available during the
+ /// analysis constructor.
+ TypeID analysisTypeID;
+
/// Allow the data-flow solver to access the internals of this class.
friend class DataFlowSolver;
};
@@ -743,6 +732,7 @@ class DataFlowAnalysis {
template <typename AnalysisT, typename... Args>
AnalysisT *DataFlowSolver::load(Args &&...args) {
childAnalyses.emplace_back(new AnalysisT(*this, std::forward<Args>(args)...));
+ childAnalyses.back()->analysisTypeID = TypeID::get<AnalysisT>();
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
childAnalyses.back()->debugName = llvm::getTypeName<AnalysisT>();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -795,28 +785,9 @@ bool DataFlowSolver::isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const {
template <typename StateT, typename AnchorT>
void DataFlowSolver::unionLatticeAnchors(AnchorT anchor, AnchorT other) {
- // States are stored on equivalence-class leaders, so canonicalize before
- // checking for materialized states. For example, if `B` is already
- // equivalent to leader `A` and `A` owns the state, a later union(B, C) must
- // observe `A`'s state. Returning early keeps redundant unions like
- // union(A, B) as harmless no-ops.
- LatticeAnchor lhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(anchor));
- LatticeAnchor rhs = getLeaderAnchorOrSelf<StateT>(LatticeAnchor(other));
- if (lhs == rhs)
- return;
-
- auto hasStateForAnchor = [&](LatticeAnchor latticeAnchor) {
- auto anchorIt = analysisStates.find(latticeAnchor);
- return anchorIt != analysisStates.end() &&
- anchorIt->second.contains(TypeID::get<StateT>());
- };
- assert(!hasStateForAnchor(lhs) && !hasStateForAnchor(rhs) &&
- "cannot union lattice anchors after analysis states have been "
- "materialized for the state type");
-
llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
equivalentAnchorMap[TypeID::get<StateT>()];
- eqClass.unionSets(lhs, rhs);
+ eqClass.unionSets(LatticeAnchor(anchor), LatticeAnchor(other));
}
inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index bdab7069a186f..dfa26af05cc7b 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -11,7 +11,6 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/iterator.h"
#include "llvm/Config/abi-breaking.h"
@@ -110,26 +109,8 @@ Location LatticeAnchor::getLoc() const {
// DataFlowSolver
//===----------------------------------------------------------------------===//
-LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
- initializedAnalysisCount = 0;
- analysisRoot = top;
- return initializeAndRunImpl(top, /*firstAnalysis=*/0);
-}
-
-LogicalResult DataFlowSolver::initializeAndRunPendingAnalyses() {
- assert(!hasFailedRun && "dataflow solver is in a failed state after a "
- "previous run; call 'initializeAndRun()' to "
- "restart or 'eraseAllStates()' before reusing it");
- assert(analysisRoot && "dataflow solver has not been run yet; call "
- "'initializeAndRun()' first");
- return initializeAndRunImpl(analysisRoot, initializedAnalysisCount);
-}
-
-LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
- size_t firstAnalysis) {
- analysisRoot = top;
- hasFailedRun = true;
-
+LogicalResult DataFlowSolver::initializeAndRun(
+ Operation *top, llvm::function_ref<bool(DataFlowAnalysis &)> analysisFilter) {
// Enable enqueue to the worklist.
isRunning = true;
llvm::scope_exit guard([&]() { isRunning = false; });
@@ -138,31 +119,28 @@ LogicalResult DataFlowSolver::initializeAndRunImpl(Operation *top,
llvm::scope_exit restoreInterprocedural(
[&]() { config.setInterprocedural(isInterprocedural); });
if (isInterprocedural && !top->hasTrait<OpTrait::SymbolTable>())
- config.setInterprocedural(false);
+ config.setInterprocedural(false);
+
+ auto shouldInitialize = [&](DataFlowAnalysis &analysis) {
+ return !analysisFilter || analysisFilter(analysis);
+ };
// Initialize equivalent lattice anchors.
- for (DataFlowAnalysis &analysis : llvm::make_pointee_range(
- llvm::drop_begin(childAnalyses, firstAnalysis))) {
+ for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+ if (!shouldInitialize(analysis))
+ continue;
analysis.initializeEquivalentLatticeAnchor(top);
}
// Initialize the analyses.
- for (DataFlowAnalysis &analysis : llvm::make_pointee_range(
- llvm::drop_begin(childAnalyses, firstAnalysis))) {
+ for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+ if (!shouldInitialize(analysis))
+ continue;
DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
if (failed(analysis.initialize(top)))
return failure();
}
- if (failed(runToFixpoint()))
- return failure();
-
- initializedAnalysisCount = childAnalyses.size();
- hasFailedRun = false;
- return success();
-}
-
-LogicalResult DataFlowSolver::runToFixpoint() {
// Run the analysis until fixpoint.
// Iterate until all states are in some initialized state and the worklist
// is exhausted.
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 57d2d8e08bf11..9af7e205aaee9 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -80,6 +80,10 @@ class FooAnalysis : public DataFlowAnalysis {
using DataFlowAnalysis::DataFlowAnalysis;
+ static bool classof(const DataFlowAnalysis *a) {
+ return a->getTypeID() == TypeID::get<FooAnalysis>();
+ }
+
LogicalResult initialize(Operation *top) override;
LogicalResult visit(ProgramPoint *point) override;
@@ -144,6 +148,10 @@ class BarAnalysis : public DataFlowAnalysis {
using DataFlowAnalysis::DataFlowAnalysis;
+ static bool classof(const DataFlowAnalysis *a) {
+ return a->getTypeID() == TypeID::get<BarAnalysis>();
+ }
+
LogicalResult initialize(Operation *top) override;
LogicalResult visit(ProgramPoint *point) override;
@@ -330,7 +338,7 @@ void TestStagedAnalysesPass::runOnOperation() {
if (failed(solver.initializeAndRun(func)))
return signalPassFailure();
solver.load<BarAnalysis>();
- if (failed(solver.initializeAndRunPendingAnalyses()))
+ if (failed(solver.initializeAndRun(func, llvm::IsaPred<BarAnalysis>)))
return signalPassFailure();
func.walk([&](Operation *op) {
More information about the Mlir-commits
mailing list