[Mlir-commits] [mlir] 6a66673 - [mlir][dataflow] Unify dependency management in AnalysisState.
Jeff Niu
llvmlistbot at llvm.org
Mon Jul 3 12:20:58 PDT 2023
Author: Zhixun Tan
Date: 2023-07-03T12:20:52-07:00
New Revision: 6a66673765b2bf45f412ab4261a72704805dd526
URL: https://github.com/llvm/llvm-project/commit/6a66673765b2bf45f412ab4261a72704805dd526
DIFF: https://github.com/llvm/llvm-project/commit/6a66673765b2bf45f412ab4261a72704805dd526.diff
LOG: [mlir][dataflow] Unify dependency management in AnalysisState.
In the MLIR dataflow analysis framework, when an `AnalysisState` is updated, its dependents are enqueued to be visited.
Currently, there are two ways dependents are managed:
* `AnalysisState::dependents` stores a list of dependents. `DataFlowSolver::propagateIfChanged()` reads this list and enqueues them to the worklist.
* `AnalysisState::onUpdate()` allows custom logic to enqueue more to the worklist. This is called by `DataFlowSolver::propagateIfChanged()`.
This cleanup diff consolidates the two into `AnalysisState::onUpdate()`. This way, `DataFlowSolver` does not need to know the details of `AnalysisState::dependents`, and the logic of dependency management is handled entirely by `AnalysisState`.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D154170
Added:
Modified:
mlir/include/mlir/Analysis/DataFlowFramework.h
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Analysis/DataFlowFramework.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 9649f918faa2f3..7b97ea4a147bdf 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -235,12 +235,6 @@ class DataFlowSolver {
/// dependent work items to the back of the queue.
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
- /// Add a dependency to an analysis state on a child analysis and program
- /// point. If the state is updated, the child analysis must be invoked on the
- /// given program point again.
- void addDependency(AnalysisState *state, DataFlowAnalysis *analysis,
- ProgramPoint point);
-
private:
/// The solver's work queue. Work items can be inserted to the front of the
/// queue to be processed greedily, speeding up computations that otherwise
@@ -294,13 +288,30 @@ class AnalysisState {
/// Print the contents of the analysis state.
virtual void print(raw_ostream &os) const = 0;
+ /// Add a dependency to this analysis state on a program point and an
+ /// analysis. If this state is updated, the analysis will be invoked on the
+ /// given program point again (in onUpdate()).
+ void addDependency(ProgramPoint dependent, DataFlowAnalysis *analysis);
+
protected:
/// This function is called by the solver when the analysis state is updated
- /// to optionally enqueue more work items. For example, if a state tracks
- /// dependents through the IR (e.g. use-def chains), this function can be
- /// implemented to push those dependents on the worklist.
- virtual void onUpdate(DataFlowSolver *solver) const {}
+ /// to enqueue more work items. For example, if a state tracks dependents
+ /// through the IR (e.g. use-def chains), this function can be implemented to
+ /// push those dependents on the worklist.
+ virtual void onUpdate(DataFlowSolver *solver) const {
+ for (const DataFlowSolver::WorkItem &item : dependents)
+ solver->enqueue(item);
+ }
+
+ /// The program point to which the state belongs.
+ ProgramPoint point;
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+ /// When compiling with debugging, keep a name for the analysis state.
+ StringRef debugName;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+private:
/// The dependency relations originating from this analysis state. An entry
/// `state -> (analysis, point)` is created when `analysis` queries `state`
/// when updating `point`.
@@ -312,14 +323,6 @@ class AnalysisState {
/// Store the dependents on the analysis state for efficiency.
SetVector<DataFlowSolver::WorkItem> dependents;
- /// The program point to which the state belongs.
- ProgramPoint point;
-
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- /// When compiling with debugging, keep a name for the analysis state.
- StringRef debugName;
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
/// Allow the framework to access the dependents.
friend class DataFlowSolver;
};
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index d681604aaff64f..30a285068a0748 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -8,6 +8,7 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include <optional>
@@ -31,6 +32,8 @@ void Executable::print(raw_ostream &os) const {
}
void Executable::onUpdate(DataFlowSolver *solver) const {
+ AnalysisState::onUpdate(solver);
+
if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
// Re-invoke the analyses on the block itself.
for (DataFlowAnalysis *analysis : subscribers)
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index f5cf866d0d2a50..3f2a69e0ed6505 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -8,6 +8,7 @@
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Interfaces/CallInterfaces.h"
using namespace mlir;
@@ -18,6 +19,8 @@ using namespace mlir::dataflow;
//===----------------------------------------------------------------------===//
void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
+ AnalysisState::onUpdate(solver);
+
// Push all users of the value to the queue.
for (Operation *user : point.get<Value>().getUsers())
for (DataFlowAnalysis *analysis : useDefSubscribers)
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 47caf268290ad7..6f9168c4ebcda5 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -30,6 +30,19 @@ GenericProgramPoint::~GenericProgramPoint() = default;
AnalysisState::~AnalysisState() = default;
+void AnalysisState::addDependency(ProgramPoint dependent,
+ DataFlowAnalysis *analysis) {
+ auto inserted = dependents.insert({dependent, analysis});
+ (void)inserted;
+ DATAFLOW_DEBUG({
+ if (inserted) {
+ llvm::dbgs() << "Creating dependency between " << debugName << " of "
+ << point << "\nand " << debugName << " on " << dependent
+ << "\n";
+ }
+ });
+}
+
//===----------------------------------------------------------------------===//
// ProgramPoint
//===----------------------------------------------------------------------===//
@@ -97,26 +110,10 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
<< " of " << state->point << "\n"
<< "Value: " << *state << "\n");
- for (const WorkItem &item : state->dependents)
- enqueue(item);
state->onUpdate(this);
}
}
-void DataFlowSolver::addDependency(AnalysisState *state,
- DataFlowAnalysis *analysis,
- ProgramPoint point) {
- auto inserted = state->dependents.insert({point, analysis});
- (void)inserted;
- DATAFLOW_DEBUG({
- if (inserted) {
- llvm::dbgs() << "Creating dependency between " << state->debugName
- << " of " << state->point << "\nand " << analysis->debugName
- << " on " << point << "\n";
- }
- });
-}
-
//===----------------------------------------------------------------------===//
// DataFlowAnalysis
//===----------------------------------------------------------------------===//
@@ -126,7 +123,7 @@ DataFlowAnalysis::~DataFlowAnalysis() = default;
DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
- solver.addDependency(state, this, point);
+ state->addDependency(point, this);
}
void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,
More information about the Mlir-commits
mailing list