[Mlir-commits] [mlir] [mlir] [dataflow] : Improve the time and space footprint of data flow. (PR #135325)
donald chen
llvmlistbot at llvm.org
Fri Apr 11 00:54:23 PDT 2025
https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/135325
MLIR's data flow analysis (especially dense data flow analysis) constructs a lattice at every lattice anchor (which, for dense data flow, means every program point). As the program grows larger, the time and space complexity can become unmanageable.
However, in many programs, the lattice values at numerous lattice anchors are actually identical. We can leverage this observation to reduce the cost of data flow analysis. This patch introduces equivalence classes of lattice anchors, which group lattice anchors that must hold identical lattices for a given state type, to reduce the time and space footprint of data flow.
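The grouping idea can be illustrated outside MLIR with a small union-find sketch. This is not the patch's implementation (the patch uses `llvm::EquivalenceClasses` keyed by `LatticeAnchor`); the names `AnchorClasses` and `countStates` are hypothetical, and program points are modeled as plain integers, with operation `i` sitting between point `i` and point `i + 1`.

```cpp
#include <cstddef>
#include <numeric>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for the solver's per-state equivalence classes.
class AnchorClasses {
public:
  explicit AnchorClasses(std::size_t numPoints) : parent(numPoints) {
    // Every program point starts as the leader of its own class.
    std::iota(parent.begin(), parent.end(), std::size_t{0});
  }

  // Return the leader (representative) of the class containing `point`.
  std::size_t leader(std::size_t point) {
    while (parent[point] != point) {
      parent[point] = parent[parent[point]]; // Path halving.
      point = parent[point];
    }
    return point;
  }

  // Merge the classes of `a` and `b`; the smaller index stays the leader.
  void unite(std::size_t a, std::size_t b) {
    a = leader(a);
    b = leader(b);
    if (a == b)
      return;
    if (b < a)
      std::swap(a, b);
    parent[b] = a;
  }

private:
  std::vector<std::size_t> parent;
};

// Count the lattice states a solver would materialize if the points before
// and after every effect-free operation share one state. `opHasEffect[i]`
// says whether operation `i` may change the analysis state.
std::size_t countStates(const std::vector<bool> &opHasEffect) {
  AnchorClasses classes(opHasEffect.size() + 1);
  for (std::size_t i = 0; i < opHasEffect.size(); ++i)
    if (!opHasEffect[i])
      classes.unite(i, i + 1);

  // States are keyed by class leader, so equivalent points share an entry.
  std::unordered_set<std::size_t> leaders;
  for (std::size_t point = 0; point <= opHasEffect.size(); ++point)
    leaders.insert(classes.leader(point));
  return leaders.size();
}
```

For four operations where only the second has an effect, `countStates({false, true, false, false})` yields two states instead of five, since points 0 and 1 collapse into one class and points 2 through 4 into another.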
From 623992fdd71d6a0b5bb5e2c89964b7c66ce752e7 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Mon, 7 Apr 2025 13:19:16 +0000
Subject: [PATCH] [mlir] [dataflow] : Improve the time and space footprint of
data flow.
MLIR's data flow analysis (especially dense data flow analysis) constructs
a lattice at every lattice anchor (which, for dense data flow, means every
program point). As the program grows larger, the time and space complexity
can become unmanageable.
However, in many programs, the lattice values at numerous lattice anchors
are actually identical. We can leverage this observation to reduce the cost
of data flow analysis. This patch introduces equivalence classes of lattice
anchors, which group lattice anchors that must hold identical lattices for
a given state type, to reduce the time and space footprint of data flow.
---
.../mlir/Analysis/DataFlow/DenseAnalysis.h | 26 ++++
.../include/mlir/Analysis/DataFlowFramework.h | 117 ++++++++++++++++--
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 18 +++
mlir/lib/Analysis/DataFlowFramework.cpp | 5 +
.../TestDenseBackwardDataFlowAnalysis.cpp | 12 ++
.../TestDenseForwardDataFlowAnalysis.cpp | 13 ++
6 files changed, 182 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 2e32bd1bc1461..68c84089c0856 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -73,6 +73,14 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// may modify the program state; that is, every operation and block.
LogicalResult initialize(Operation *top) override;
+ /// Initialize lattice anchor equivalence classes from the provided top-level
+ /// operation.
+ ///
+ /// This function unions lattice anchors into the same equivalence class if
+ /// the analysis can determine that their lattice contents are necessarily
+ /// identical under the corresponding lattice type.
+ void initializeEquivalentLatticeAnchor(Operation *top) override;
+
/// Visit a program point that modifies the state of the program. If the
/// program point is at the beginning of a block, then the state is propagated
/// from control-flow predecessors or callsites. If the operation before
@@ -114,6 +122,11 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// operation transfer function.
virtual LogicalResult processOperation(Operation *op);
+ /// Visit an operation. If this analysis can confirm that the lattice contents
+ /// of the lattice anchors around the operation are necessarily identical,
+ /// union them into the same equivalence class.
+ virtual void buildOperationEquivalentLatticeAnchor(Operation *op) {}
+
/// Propagate the dense lattice forward along the control flow edge from
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
/// values correspond to control flow branches originating at or targeting the
@@ -310,6 +323,14 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// may modify the program state; that is, every operation and block.
LogicalResult initialize(Operation *top) override;
+ /// Initialize lattice anchor equivalence classes from the provided top-level
+ /// operation.
+ ///
+ /// This function unions lattice anchors into the same equivalence class if
+ /// the analysis can determine that their lattice contents are necessarily
+ /// identical under the corresponding lattice type.
+ void initializeEquivalentLatticeAnchor(Operation *top) override;
+
/// Visit a program point that modifies the state of the program. The state is
/// propagated along control flow directions for branch-, region- and
/// call-based control flow using the respective interfaces. For other
@@ -353,6 +374,11 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// transfer function.
virtual LogicalResult processOperation(Operation *op);
+ /// Visit an operation. If this analysis can confirm that the lattice contents
+ /// of the lattice anchors around the operation are necessarily identical,
+ /// union them into the same equivalence class.
+ virtual void buildOperationEquivalentLatticeAnchor(Operation *op) {}
+
/// Propagate the dense lattice backwards along the control flow edge from
/// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt`
/// values correspond to control flow branches originating at or targeting the
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 6aa0900d1412a..9b4b41f1c35b2 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -18,6 +18,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/Support/StorageUniquer.h"
+#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Compiler.h"
@@ -265,6 +266,14 @@ struct LatticeAnchor
/// Forward declaration of the data-flow analysis class.
class DataFlowAnalysis;
+} // namespace mlir
+
+template <>
+struct llvm::DenseMapInfo<mlir::LatticeAnchor>
+ : public llvm::DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
+
+namespace mlir {
+
//===----------------------------------------------------------------------===//
// DataFlowConfig
//===----------------------------------------------------------------------===//
@@ -332,7 +341,9 @@ class DataFlowSolver {
/// does not exist.
template <typename StateT, typename AnchorT>
const StateT *lookupState(AnchorT anchor) const {
- const auto &mapIt = analysisStates.find(LatticeAnchor(anchor));
+ LatticeAnchor latticeAnchor =
+ getLeaderAnchorOrSelf<StateT>(LatticeAnchor(anchor));
+ const auto &mapIt = analysisStates.find(latticeAnchor);
if (mapIt == analysisStates.end())
return nullptr;
auto it = mapIt->second.find(TypeID::get<StateT>());
@@ -344,12 +355,34 @@ class DataFlowSolver {
/// Erase any analysis state associated with the given lattice anchor.
template <typename AnchorT>
void eraseState(AnchorT anchor) {
- LatticeAnchor la(anchor);
- analysisStates.erase(LatticeAnchor(anchor));
+ LatticeAnchor latticeAnchor(anchor);
+
+ // Update equivalentAnchorMap.
+ for (auto &&[TypeId, eqClass] : equivalentAnchorMap) {
+ if (!eqClass.contains(latticeAnchor)) {
+ continue;
+ }
+ llvm::EquivalenceClasses<LatticeAnchor>::member_iterator leaderIt =
+ eqClass.findLeader(latticeAnchor);
+
+ // Update analysis states with new leader if needed.
+ if (*leaderIt == latticeAnchor && ++leaderIt != eqClass.member_end()) {
+ analysisStates[*leaderIt][TypeId] =
+ std::move(analysisStates[latticeAnchor][TypeId]);
+ }
+
+ eqClass.erase(latticeAnchor);
+ }
+
+ // Update analysis states.
+ analysisStates.erase(latticeAnchor);
}
- // Erase all analysis states
- void eraseAllStates() { analysisStates.clear(); }
+ // Erase all analysis states.
+ void eraseAllStates() {
+ analysisStates.clear();
+ equivalentAnchorMap.clear();
+ }
/// Get a uniqued lattice anchor instance. If one is not present, it is
/// created with the provided arguments.
@@ -399,6 +432,19 @@ class DataFlowSolver {
template <typename StateT, typename AnchorT>
StateT *getOrCreateState(AnchorT anchor);
+ /// Get the leader of the equivalence class containing the given lattice
+ /// anchor, or the anchor itself if it belongs to no equivalence class.
+ template <typename StateT>
+ LatticeAnchor getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const;
+
+ /// Union input anchors under the given state.
+ template <typename StateT, typename AnchorT>
+ void unionLatticeAnchors(AnchorT anchor, AnchorT other);
+
+ /// Return true if the given lattice anchors are equivalent under the state.
+ template <typename StateT>
+ bool isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const;
+
/// Propagate an update to an analysis state if it changed by pushing
/// dependent work items to the back of the queue.
/// This should only be used when DataFlowSolver is running.
@@ -429,10 +475,15 @@ class DataFlowSolver {
/// A type-erased map of lattice anchors to associated analysis states for
/// first-class lattice anchors.
- DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>,
- DenseMapInfo<LatticeAnchor::ParentTy>>
+ DenseMap<LatticeAnchor, DenseMap<TypeID, std::unique_ptr<AnalysisState>>>
analysisStates;
+ /// A type-erased map of lattice type to the equivalent lattice anchors.
+ /// Lattice anchors are considered equivalent under a certain lattice type if
+ /// and only if, under this lattice type, the lattices pointed to by these
+ /// lattice anchors necessarily contain identical values.
+ DenseMap<TypeID, llvm::EquivalenceClasses<LatticeAnchor>> equivalentAnchorMap;
+
/// Allow the base child analysis class to access the internals of the solver.
friend class DataFlowAnalysis;
};
@@ -564,6 +615,14 @@ class DataFlowAnalysis {
/// will provide a value for then.
virtual LogicalResult visit(ProgramPoint *point) = 0;
+ /// Initialize lattice anchor equivalence classes from the provided top-level
+ /// operation.
+ ///
+ /// This function unions lattice anchors into the same equivalence class if
+ /// the analysis can determine that their lattice contents are necessarily
+ /// identical under the corresponding lattice type.
+ virtual void initializeEquivalentLatticeAnchor(Operation *top) {}
+
protected:
/// Create a dependency between the given analysis state and lattice anchor
/// on this analysis.
@@ -584,6 +643,12 @@ class DataFlowAnalysis {
return solver.getLatticeAnchor<AnchorT>(std::forward<Args>(args)...);
}
+ /// Union input anchors under the given state.
+ template <typename StateT, typename AnchorT>
+ void unionLatticeAnchors(AnchorT anchor, AnchorT other) {
+ return solver.unionLatticeAnchors<StateT>(anchor, other);
+ }
+
/// Get the analysis state associated with the lattice anchor. The returned
/// state is expected to be "write-only", and any updates need to be
/// propagated by `propagateIfChanged`.
@@ -598,7 +663,9 @@ class DataFlowAnalysis {
template <typename StateT, typename AnchorT>
const StateT *getOrCreateFor(ProgramPoint *dependent, AnchorT anchor) {
StateT *state = getOrCreate<StateT>(anchor);
- addDependency(state, dependent);
+ if (!solver.isEquivalent<StateT>(LatticeAnchor(anchor),
+ LatticeAnchor(dependent)))
+ addDependency(state, dependent);
return state;
}
@@ -644,10 +711,26 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
return static_cast<AnalysisT *>(childAnalyses.back().get());
}
+template <typename StateT>
+LatticeAnchor
+DataFlowSolver::getLeaderAnchorOrSelf(LatticeAnchor latticeAnchor) const {
+ // Avoid `lookup`, which returns the whole equivalence class set by value.
+ auto eqClassIt = equivalentAnchorMap.find(TypeID::get<StateT>());
+ if (eqClassIt == equivalentAnchorMap.end())
+ return latticeAnchor;
+ auto leaderIt = eqClassIt->second.findLeader(latticeAnchor);
+ if (leaderIt == eqClassIt->second.member_end())
+ return latticeAnchor;
+ return *leaderIt;
+}
+
template <typename StateT, typename AnchorT>
StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
+ // Use the leader anchor of the equivalence class if one exists.
+ LatticeAnchor latticeAnchor(anchor);
+ latticeAnchor = getLeaderAnchorOrSelf<StateT>(latticeAnchor);
std::unique_ptr<AnalysisState> &state =
- analysisStates[LatticeAnchor(anchor)][TypeID::get<StateT>()];
+ analysisStates[latticeAnchor][TypeID::get<StateT>()];
if (!state) {
state = std::unique_ptr<StateT>(new StateT(anchor));
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -657,6 +740,22 @@ StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
return static_cast<StateT *>(state.get());
}
+template <typename StateT>
+bool DataFlowSolver::isEquivalent(LatticeAnchor lhs, LatticeAnchor rhs) const {
+ auto it = equivalentAnchorMap.find(TypeID::get<StateT>());
+ if (it == equivalentAnchorMap.end())
+ return false;
+ return it->second.contains(lhs) && it->second.contains(rhs) &&
+ it->second.isEquivalent(lhs, rhs);
+}
+
+template <typename StateT, typename AnchorT>
+void DataFlowSolver::unionLatticeAnchors(AnchorT anchor, AnchorT other) {
+ llvm::EquivalenceClasses<LatticeAnchor> &eqClass =
+ equivalentAnchorMap[TypeID::get<StateT>()];
+ eqClass.unionSets(LatticeAnchor(anchor), LatticeAnchor(other));
+}
+
inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
state.print(os);
return os;
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 340aa399ec12e..7ec5b58425e91 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -28,6 +28,15 @@ using namespace mlir::dataflow;
// AbstractDenseForwardDataFlowAnalysis
//===----------------------------------------------------------------------===//
+void AbstractDenseForwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
+ Operation *top) {
+ top->walk([&](Operation *op) {
+ if (isa<RegionBranchOpInterface, CallOpInterface>(op))
+ return;
+ buildOperationEquivalentLatticeAnchor(op);
+ });
+}
+
LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
// Visit every operation and block.
if (failed(processOperation(top)))
@@ -252,6 +261,15 @@ AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint *dependent,
// AbstractDenseBackwardDataFlowAnalysis
//===----------------------------------------------------------------------===//
+void AbstractDenseBackwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
+ Operation *top) {
+ top->walk([&](Operation *op) {
+ if (isa<RegionBranchOpInterface, CallOpInterface>(op))
+ return;
+ buildOperationEquivalentLatticeAnchor(op);
+ });
+}
+
LogicalResult
AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
// Visit every operation and block.
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 29f57c602f9cb..176d53e017c9f 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -109,6 +109,11 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
isRunning = true;
auto guard = llvm::make_scope_exit([&]() { isRunning = false; });
+ // Initialize equivalent lattice anchors.
+ for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
+ analysis.initializeEquivalentLatticeAnchor(top);
+ }
+
// Initialize the analyses.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
DATAFLOW_DEBUG(llvm::dbgs()
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index fa6223aa9168b..da543f4f04f97 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -76,6 +76,11 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
propagateIfChanged(lattice, lattice->setKnownToUnknown());
}
+ /// Visit an operation. If this analysis can confirm that the lattice contents
+ /// of the lattice anchors around the operation are necessarily identical,
+ /// union them into the same equivalence class.
+ void buildOperationEquivalentLatticeAnchor(Operation *op) override;
+
const bool assumeFuncReads;
};
} // namespace
@@ -141,6 +146,13 @@ LogicalResult NextAccessAnalysis::visitOperation(Operation *op,
return success();
}
+void NextAccessAnalysis::buildOperationEquivalentLatticeAnchor(Operation *op) {
+ if (isMemoryEffectFree(op)) {
+ unionLatticeAnchors<NextAccess>(getProgramPointBefore(op),
+ getProgramPointAfter(op));
+ }
+}
+
void NextAccessAnalysis::visitCallControlFlowTransfer(
CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
NextAccess *before) {
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
index 89b5c835744fd..f4f8e9115a3fa 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
@@ -72,6 +72,11 @@ class LastModifiedAnalysis
const LastModification &before,
LastModification *after) override;
+ /// Visit an operation. If this analysis can confirm that the lattice contents
+ /// of the lattice anchors around the operation are necessarily identical,
+ /// union them into the same equivalence class.
+ void buildOperationEquivalentLatticeAnchor(Operation *op) override;
+
/// At an entry point, the last modifications of all memory resources are
/// unknown.
void setToEntryState(LastModification *lattice) override {
@@ -147,6 +152,14 @@ LogicalResult LastModifiedAnalysis::visitOperation(
return success();
}
+void LastModifiedAnalysis::buildOperationEquivalentLatticeAnchor(
+ Operation *op) {
+ if (isMemoryEffectFree(op)) {
+ unionLatticeAnchors<LastModification>(getProgramPointBefore(op),
+ getProgramPointAfter(op));
+ }
+}
+
void LastModifiedAnalysis::visitCallControlFlowTransfer(
CallOpInterface call, CallControlFlowAction action,
const LastModification &before, LastModification *after) {