[Mlir-commits] [mlir] [mlir] [dataflow] Refactoring the definition of program points in dat… (PR #105656)
donald chen
llvmlistbot at llvm.org
Thu Aug 22 06:18:18 PDT 2024
https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/105656
…a flow analysis
This patch distinguishes between program points and lattice anchors in data flow analysis, where lattice anchors represent locations to which a lattice can be attached, while program points denote points in program execution.
Related discussions: https://discourse.llvm.org/t/rfc-unify-the-semantics-of-program-points/80671/8
>From 814478cfa43aaed25b8211152d5d7b84ef8b7a0a Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Mon, 19 Aug 2024 12:05:15 +0000
Subject: [PATCH] [mlir] [dataflow] Refactoring the definition of program
points in data flow analysis
This patch distinguishes between program points and lattice anchors in data flow
analysis, where lattice anchors represent locations to which a lattice can be
attached, while program points denote points in program execution.
Related discussions: https://discourse.llvm.org/t/rfc-unify-the-semantics-of-program-points/80671/8
---
.../mlir/Analysis/DataFlow/DeadCodeAnalysis.h | 16 +-
.../mlir/Analysis/DataFlow/DenseAnalysis.h | 37 +--
.../Analysis/DataFlow/IntegerRangeAnalysis.h | 2 +-
.../mlir/Analysis/DataFlow/SparseAnalysis.h | 8 +-
.../include/mlir/Analysis/DataFlowFramework.h | 248 ++++++++++--------
.../Analysis/DataFlow/DeadCodeAnalysis.cpp | 30 ++-
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 32 +--
.../DataFlow/IntegerRangeAnalysis.cpp | 2 +-
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 32 +--
mlir/lib/Analysis/DataFlowFramework.cpp | 45 ++--
.../DataFlow/TestDeadCodeAnalysis.cpp | 2 +-
.../DataFlow/TestDenseDataFlowAnalysis.h | 2 +-
.../lib/Analysis/TestDataFlowFramework.cpp | 7 +-
13 files changed, 250 insertions(+), 213 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 10ef8b6ba5843a..80c8b86c63678a 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -35,21 +35,21 @@ namespace dataflow {
//===----------------------------------------------------------------------===//
/// This is a simple analysis state that represents whether the associated
-/// program point (either a block or a control-flow edge) is live.
+/// lattice anchor (either a block or a control-flow edge) is live.
class Executable : public AnalysisState {
public:
using AnalysisState::AnalysisState;
- /// Set the state of the program point to live.
+ /// Set the state of the lattice anchor to live.
ChangeResult setToLive();
- /// Get whether the program point is live.
+ /// Get whether the lattice anchor is live.
bool isLive() const { return live; }
/// Print the liveness.
void print(raw_ostream &os) const override;
- /// When the state of the program point is changed to live, re-invoke
+ /// When the state of the lattice anchor is changed to live, re-invoke
/// subscribed analyses on the operations in the block and on the block
/// itself.
void onUpdate(DataFlowSolver *solver) const override;
@@ -60,8 +60,8 @@ class Executable : public AnalysisState {
}
private:
- /// Whether the program point is live. Optimistically assume that the program
- /// point is dead.
+ /// Whether the lattice anchor is live. Optimistically assume that the lattice
+ /// anchor is dead.
bool live = false;
/// A set of analyses that should be updated when this state changes.
@@ -140,10 +140,10 @@ class PredecessorState : public AnalysisState {
// CFGEdge
//===----------------------------------------------------------------------===//
-/// This program point represents a control-flow edge between a block and one
+/// This lattice anchor represents a control-flow edge between a block and one
/// of its successors.
class CFGEdge
- : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
+ : public GenericLatticeAnchorBase<CFGEdge, std::pair<Block *, Block *>> {
public:
using Base::Base;
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 4ad5f3fcd838c0..7917f1e3ba6485 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -91,15 +91,16 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
const AbstractDenseLattice &before,
AbstractDenseLattice *after) = 0;
- /// Get the dense lattice after the execution of the given program point.
- virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
+ /// Get the dense lattice after the execution of the given lattice anchor.
+ virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
/// Get the dense lattice after the execution of the given program point and
- /// add it as a dependency to a program point. That is, every time the lattice
- /// after point is updated, the dependent program point must be visited, and
- /// the newly triggered visit might update the lattice after dependent.
+ /// add it as a dependency to a program point. That is, every time the
+ /// lattice after anchor is updated, the dependent program point must be
+ /// visited, and the newly triggered visit might update the lattice after
+ /// dependent.
const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
- ProgramPoint point);
+ LatticeAnchor anchor);
/// Set the dense lattice at control flow entry point and propagate an update
/// if it changed.
@@ -249,9 +250,9 @@ class DenseForwardDataFlowAnalysis
}
protected:
- /// Get the dense lattice after this program point.
- LatticeT *getLattice(ProgramPoint point) override {
- return getOrCreate<LatticeT>(point);
+ /// Get the dense lattice on this lattice anchor.
+ LatticeT *getLattice(LatticeAnchor anchor) override {
+ return getOrCreate<LatticeT>(anchor);
}
/// Set the dense lattice at control flow entry point and propagate an update
@@ -331,16 +332,16 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
const AbstractDenseLattice &after,
AbstractDenseLattice *before) = 0;
- /// Get the dense lattice before the execution of the program point. That is,
+ /// Get the dense lattice before the execution of the lattice anchor. That is,
/// before the execution of the given operation or after the execution of the
/// block.
- virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
+ virtual AbstractDenseLattice *getLattice(LatticeAnchor anchor) = 0;
- /// Get the dense lattice before the execution of the program point `point`
- /// and declare that the `dependent` program point must be updated every time
- /// `point` is.
+ /// Get the dense lattice before the execution of the program point in
+ /// `anchor` and declare that the `dependent` program point must be updated
+ /// every time `anchor` is.
const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
- ProgramPoint point);
+ LatticeAnchor anchor);
/// Set the dense lattice before at the control flow exit point and propagate
/// the update if it changed.
@@ -500,9 +501,9 @@ class DenseBackwardDataFlowAnalysis
}
protected:
- /// Get the dense lattice at the given program point.
- LatticeT *getLattice(ProgramPoint point) override {
- return getOrCreate<LatticeT>(point);
+ /// Get the dense lattice at the given lattice anchor.
+ LatticeT *getLattice(LatticeAnchor anchor) override {
+ return getOrCreate<LatticeT>(anchor);
}
/// Set the dense lattice at control flow exit point (after the terminator)
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index d4a5472cfde868..f99eae379596b6 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -50,7 +50,7 @@ class IntegerRangeAnalysis
/// At an entry point, we cannot reason about interger value ranges.
void setToEntryState(IntegerValueRangeLattice *lattice) override {
propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange(
- lattice->getPoint())));
+ lattice->getAnchor())));
}
/// Visit an operation. Invoke the transfer function on each operation that
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 89726ae3a855c8..933790b4f2a6eb 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -36,8 +36,8 @@ class AbstractSparseLattice : public AnalysisState {
/// Lattices can only be created for values.
AbstractSparseLattice(Value value) : AnalysisState(value) {}
- /// Return the program point this lattice is located at.
- Value getPoint() const { return AnalysisState::getPoint().get<Value>(); }
+ /// Return the value this lattice is located at.
+ Value getAnchor() const { return AnalysisState::getAnchor().get<Value>(); }
/// Join the information contained in 'rhs' into this lattice. Returns
/// if the value of the lattice changed.
@@ -86,8 +86,8 @@ class Lattice : public AbstractSparseLattice {
public:
using AbstractSparseLattice::AbstractSparseLattice;
- /// Return the program point this lattice is located at.
- Value getPoint() const { return point.get<Value>(); }
+ /// Return the value this lattice is located at.
+ Value getAnchor() const { return anchor.get<Value>(); }
/// Return the value held by this lattice. This requires that the value is
/// initialized.
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 2580ec28b51902..472c5cd5406946 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -31,9 +31,9 @@ namespace mlir {
/// A result type used to indicate if a change happened. Boolean operations on
/// ChangeResult behave as though `Change` is truth.
-enum class [[nodiscard]] ChangeResult {
- NoChange,
- Change,
+enum class [[nodiscard]] ChangeResult{
+ NoChange,
+ Change,
};
inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
return lhs == ChangeResult::Change ? lhs : rhs;
@@ -49,79 +49,93 @@ inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
/// Forward declare the analysis state class.
class AnalysisState;
+/// Program point represents a specific location in the execution of a program.
+/// A sequence of program points can be combined into a control flow graph.
+struct ProgramPoint : public PointerUnion<Operation *, Block *> {
+ using ParentTy = PointerUnion<Operation *, Block *>;
+ /// Inherit constructors.
+ using ParentTy::PointerUnion;
+ /// Allow implicit conversion from the parent type.
+ ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
+ /// Allow implicit conversions from operation wrappers.
+ /// TODO: For Windows only. Find a better solution.
+ template <typename OpT, typename = std::enable_if_t<
+ std::is_convertible<OpT, Operation *>::value &&
+ !std::is_same<OpT, Operation *>::value>>
+ ProgramPoint(OpT op) : ParentTy(op) {}
+
+ /// Print the program point.
+ void print(raw_ostream &os) const;
+};
+
//===----------------------------------------------------------------------===//
-// GenericProgramPoint
+// GenericLatticeAnchor
//===----------------------------------------------------------------------===//
-/// Abstract class for generic program points. In classical data-flow analysis,
-/// programs points represent positions in a program to which lattice elements
+/// Abstract class for generic lattice anchors. In classical data-flow analysis,
+/// lattice anchors represent positions in a program to which lattice elements
/// are attached. In sparse data-flow analysis, these can be SSA values, and in
/// dense data-flow analysis, these are the program points before and after
/// every operation.
///
-/// In the general MLIR data-flow analysis framework, program points are an
-/// extensible concept. Program points are uniquely identifiable objects to
-/// which analysis states can be attached. The semantics of program points are
-/// defined by the analyses that specify their transfer functions.
-///
-/// Program points are implemented using MLIR's storage uniquer framework and
+/// Lattice anchors are implemented using MLIR's storage uniquer framework and
/// type ID system to provide RTTI.
-class GenericProgramPoint : public StorageUniquer::BaseStorage {
+class GenericLatticeAnchor : public StorageUniquer::BaseStorage {
public:
- virtual ~GenericProgramPoint();
+ virtual ~GenericLatticeAnchor();
- /// Get the abstract program point's type identifier.
+ /// Get the abstract lattice anchor's type identifier.
TypeID getTypeID() const { return typeID; }
- /// Get a derived source location for the program point.
+ /// Get a derived source location for the lattice anchor.
virtual Location getLoc() const = 0;
- /// Print the program point.
+ /// Print the lattice anchor.
virtual void print(raw_ostream &os) const = 0;
protected:
- /// Create an abstract program point with type identifier.
- explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {}
+ /// Create an abstract lattice anchor with type identifier.
+ explicit GenericLatticeAnchor(TypeID typeID) : typeID(typeID) {}
private:
- /// The type identifier of the program point.
+ /// The type identifier of the lattice anchor.
TypeID typeID;
};
//===----------------------------------------------------------------------===//
-// GenericProgramPointBase
+// GenericLatticeAnchorBase
//===----------------------------------------------------------------------===//
-/// Base class for generic program points based on a concrete program point
+/// Base class for generic lattice anchors based on a concrete lattice anchor
/// type and a content key. This class defines the common methods required for
/// operability with the storage uniquer framework.
///
-/// The provided key type uniquely identifies the concrete program point
+/// The provided key type uniquely identifies the concrete lattice anchor
/// instance and are the data members of the class.
template <typename ConcreteT, typename Value>
-class GenericProgramPointBase : public GenericProgramPoint {
+class GenericLatticeAnchorBase : public GenericLatticeAnchor {
public:
/// The concrete key type used by the storage uniquer. This class is uniqued
/// by its contents.
using KeyTy = Value;
/// Alias for the base class.
- using Base = GenericProgramPointBase<ConcreteT, Value>;
+ using Base = GenericLatticeAnchorBase<ConcreteT, Value>;
- /// Construct an instance of the program point using the provided value and
+ /// Construct an instance of the lattice anchor using the provided value and
/// the type ID of the concrete type.
template <typename ValueT>
- explicit GenericProgramPointBase(ValueT &&value)
- : GenericProgramPoint(TypeID::get<ConcreteT>()),
+ explicit GenericLatticeAnchorBase(ValueT &&value)
+ : GenericLatticeAnchor(TypeID::get<ConcreteT>()),
value(std::forward<ValueT>(value)) {}
- /// Get a uniqued instance of this program point class with the given
+ /// Get a uniqued instance of this lattice anchor class with the given
/// arguments.
template <typename... Args>
- static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
+ static ConcreteT *get(StorageUniquer &uniquer, Args &&... args) {
return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
}
- /// Allocate space for a program point and construct it in-place.
+ /// Allocate space for a lattice anchor and construct it in-place.
template <typename ValueT>
static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
ValueT &&value) {
@@ -129,46 +143,48 @@ class GenericProgramPointBase : public GenericProgramPoint {
ConcreteT(std::forward<ValueT>(value));
}
- /// Two program points are equal if their values are equal.
+ /// Two lattice anchors are equal if their values are equal.
bool operator==(const Value &value) const { return this->value == value; }
/// Provide LLVM-style RTTI using type IDs.
- static bool classof(const GenericProgramPoint *point) {
+ static bool classof(const GenericLatticeAnchor *point) {
return point->getTypeID() == TypeID::get<ConcreteT>();
}
- /// Get the contents of the program point.
+ /// Get the contents of the lattice anchor.
const Value &getValue() const { return value; }
private:
- /// The program point value.
+ /// The lattice anchor value.
Value value;
};
//===----------------------------------------------------------------------===//
-// ProgramPoint
+// LatticeAnchor
//===----------------------------------------------------------------------===//
-/// Fundamental IR components are supported as first-class program points.
-struct ProgramPoint
- : public PointerUnion<GenericProgramPoint *, Operation *, Value, Block *> {
- using ParentTy =
- PointerUnion<GenericProgramPoint *, Operation *, Value, Block *>;
+/// Fundamental IR components are supported as first-class lattice anchors.
+struct LatticeAnchor
+ : public PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value> {
+ using ParentTy = PointerUnion<GenericLatticeAnchor *, ProgramPoint, Value>;
/// Inherit constructors.
using ParentTy::PointerUnion;
/// Allow implicit conversion from the parent type.
- ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
+ LatticeAnchor(ParentTy point = nullptr) : ParentTy(point) {}
/// Allow implicit conversions from operation wrappers.
/// TODO: For Windows only. Find a better solution.
template <typename OpT, typename = std::enable_if_t<
std::is_convertible<OpT, Operation *>::value &&
!std::is_same<OpT, Operation *>::value>>
- ProgramPoint(OpT op) : ParentTy(op) {}
+ LatticeAnchor(OpT op) : ParentTy(ProgramPoint(op)) {}
- /// Print the program point.
+ LatticeAnchor(Operation *op) : ParentTy(ProgramPoint(op)) {}
+ LatticeAnchor(Block *block) : ParentTy(ProgramPoint(block)) {}
+
+ /// Print the lattice anchor.
void print(raw_ostream &os) const;
- /// Get the source location of the program point.
+ /// Get the source location of the lattice anchor.
Location getLoc() const;
};
@@ -207,8 +223,8 @@ class DataFlowConfig {
/// The general data-flow analysis solver. This class is responsible for
/// orchestrating child data-flow analyses, running the fixed-point iteration
-/// algorithm, managing analysis state and program point memory, and tracking
-/// dependencies between analyses, program points, and analysis states.
+/// algorithm, managing analysis state and lattice anchor memory, and tracking
+/// dependencies between analyses, lattice anchors, and analysis states.
///
/// Steps to run a data-flow analysis:
///
@@ -226,38 +242,39 @@ class DataFlowSolver {
/// Load an analysis into the solver. Return the analysis instance.
template <typename AnalysisT, typename... Args>
- AnalysisT *load(Args &&...args);
+ AnalysisT *load(Args &&... args);
/// Initialize the children analyses starting from the provided top-level
/// operation and run the analysis until fixpoint.
LogicalResult initializeAndRun(Operation *top);
- /// Lookup an analysis state for the given program point. Returns null if one
+ /// Lookup an analysis state for the given lattice anchor. Returns null if one
/// does not exist.
- template <typename StateT, typename PointT>
- const StateT *lookupState(PointT point) const {
- auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
+ template <typename StateT, typename AnchorT>
+ const StateT *lookupState(AnchorT anchor) const {
+ auto it =
+ analysisStates.find({LatticeAnchor(anchor), TypeID::get<StateT>()});
if (it == analysisStates.end())
return nullptr;
return static_cast<const StateT *>(it->second.get());
}
- /// Erase any analysis state associated with the given program point.
- template <typename PointT>
- void eraseState(PointT point) {
- ProgramPoint pp(point);
+ /// Erase any analysis state associated with the given lattice anchor.
+ template <typename AnchorT>
+ void eraseState(AnchorT anchor) {
+ LatticeAnchor la(anchor);
for (auto it = analysisStates.begin(); it != analysisStates.end(); ++it) {
- if (it->first.first == pp)
+ if (it->first.first == la)
analysisStates.erase(it);
}
}
- /// Get a uniqued program point instance. If one is not present, it is
+ /// Get a uniqued lattice anchor instance. If one is not present, it is
/// created with the provided arguments.
- template <typename PointT, typename... Args>
- PointT *getProgramPoint(Args &&...args) {
- return PointT::get(uniquer, std::forward<Args>(args)...);
+ template <typename AnchorT, typename... Args>
+ AnchorT *getLatticeAnchor(Args &&... args) {
+ return AnchorT::get(uniquer, std::forward<Args>(args)...);
}
/// A work item on the solver queue is a program point, child analysis pair.
@@ -267,10 +284,10 @@ class DataFlowSolver {
/// Push a work item onto the worklist.
void enqueue(WorkItem item) { worklist.push(std::move(item)); }
- /// Get the state associated with the given program point. If it does not
+ /// Get the state associated with the given lattice anchor. If it does not
/// exist, create an uninitialized state.
- template <typename StateT, typename PointT>
- StateT *getOrCreateState(PointT point);
+ template <typename StateT, typename AnchorT>
+ StateT *getOrCreateState(AnchorT anchor);
/// Propagate an update to an analysis state if it changed by pushing
/// dependent work items to the back of the queue.
@@ -291,13 +308,13 @@ class DataFlowSolver {
/// Type-erased instances of the children analyses.
SmallVector<std::unique_ptr<DataFlowAnalysis>> childAnalyses;
- /// The storage uniquer instance that owns the memory of the allocated program
- /// points.
+ /// The storage uniquer instance that owns the memory of the allocated lattice
+ /// anchors.
StorageUniquer uniquer;
- /// A type-erased map of program points to associated analysis states for
- /// first-class program points.
- DenseMap<std::pair<ProgramPoint, TypeID>, std::unique_ptr<AnalysisState>>
+ /// A type-erased map of lattice anchors to associated analysis states for
+ /// first-class lattice anchors.
+ DenseMap<std::pair<LatticeAnchor, TypeID>, std::unique_ptr<AnalysisState>>
analysisStates;
/// Allow the base child analysis class to access the internals of the solver.
@@ -309,13 +326,13 @@ class DataFlowSolver {
//===----------------------------------------------------------------------===//
/// Base class for generic analysis states. Analysis states contain data-flow
-/// information that are attached to program points and which evolve as the
+/// information that are attached to lattice anchors and which evolve as the
/// analysis iterates.
///
/// This class places no restrictions on the semantics of analysis states beyond
/// these requirements.
///
-/// 1. Querying the state of a program point prior to visiting that point
+/// 1. Querying the state of a lattice anchor prior to visiting that anchor
/// results in uninitialized state. Analyses must be aware of unintialized
/// states.
/// 2. Analysis states can reach fixpoints, where subsequent updates will never
@@ -326,20 +343,20 @@ class AnalysisState {
public:
virtual ~AnalysisState();
- /// Create the analysis state at the given program point.
- AnalysisState(ProgramPoint point) : point(point) {}
+ /// Create the analysis state at the given lattice anchor.
+ AnalysisState(LatticeAnchor anchor) : anchor(anchor) {}
- /// Returns the program point this state is located at.
- ProgramPoint getPoint() const { return point; }
+ /// Returns the lattice anchor this state is located at.
+ LatticeAnchor getAnchor() const { return anchor; }
/// Print the contents of the analysis state.
virtual void print(raw_ostream &os) const = 0;
LLVM_DUMP_METHOD void dump() const;
- /// Add a dependency to this analysis state on a program point and an
+ /// Add a dependency to this analysis state on a program point and an
/// analysis. If this state is updated, the analysis will be invoked on the
- /// given program point again (in onUpdate()).
- void addDependency(ProgramPoint dependent, DataFlowAnalysis *analysis);
+ /// given program point again (in onUpdate()).
+ void addDependency(ProgramPoint point, DataFlowAnalysis *analysis);
protected:
/// This function is called by the solver when the analysis state is updated
@@ -351,8 +368,8 @@ class AnalysisState {
solver->enqueue(item);
}
- /// The program point to which the state belongs.
- ProgramPoint point;
+ /// The lattice anchor to which the state belongs.
+ LatticeAnchor anchor;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// When compiling with debugging, keep a name for the analysis state.
@@ -361,8 +378,8 @@ class AnalysisState {
private:
/// The dependency relations originating from this analysis state. An entry
- /// `state -> (analysis, point)` is created when `analysis` queries `state`
- /// when updating `point`.
+ /// `state -> (analysis, anchor)` is created when `analysis` queries `state`
+ /// when updating `anchor`.
///
/// When this state is updated, all dependent child analysis invocations are
/// pushed to the back of the queue. Use a `SetVector` to keep the analysis
@@ -403,7 +420,7 @@ class DataFlowAnalysis {
explicit DataFlowAnalysis(DataFlowSolver &solver);
/// Initialize the analysis from the provided top-level operation by building
- /// an initial dependency graph between all program points of interest. This
+ /// an initial dependency graph between all lattice anchors of interest. This
/// can be implemented by calling `visit` on all program points of interest
/// below the top-level operation.
///
@@ -432,39 +449,39 @@ class DataFlowAnalysis {
virtual LogicalResult visit(ProgramPoint point) = 0;
protected:
- /// Create a dependency between the given analysis state and program point
+ /// Create a dependency between the given analysis state and program point
/// on this analysis.
void addDependency(AnalysisState *state, ProgramPoint point);
/// Propagate an update to a state if it changed.
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
- /// Register a custom program point class.
- template <typename PointT>
- void registerPointKind() {
- solver.uniquer.registerParametricStorageType<PointT>();
+ /// Register a custom lattice anchor class.
+ template <typename AnchorT>
+ void registerAnchorKind() {
+ solver.uniquer.registerParametricStorageType<AnchorT>();
}
- /// Get or create a custom program point.
- template <typename PointT, typename... Args>
- PointT *getProgramPoint(Args &&...args) {
- return solver.getProgramPoint<PointT>(std::forward<Args>(args)...);
+ /// Get or create a custom lattice anchor.
+ template <typename AnchorT, typename... Args>
+ AnchorT *getLatticeAnchor(Args &&... args) {
+ return solver.getLatticeAnchor<AnchorT>(std::forward<Args>(args)...);
}
- /// Get the analysis state associated with the program point. The returned
+ /// Get the analysis state associated with the lattice anchor. The returned
/// state is expected to be "write-only", and any updates need to be
/// propagated by `propagateIfChanged`.
- template <typename StateT, typename PointT>
- StateT *getOrCreate(PointT point) {
- return solver.getOrCreateState<StateT>(point);
+ template <typename StateT, typename AnchorT>
+ StateT *getOrCreate(AnchorT anchor) {
+ return solver.getOrCreateState<StateT>(anchor);
}
/// Get a read-only analysis state for the given point and create a dependency
/// on `dependent`. If the return state is updated elsewhere, this analysis is
/// re-invoked on the dependent.
- template <typename StateT, typename PointT>
- const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) {
- StateT *state = getOrCreate<StateT>(point);
+ template <typename StateT, typename AnchorT>
+ const StateT *getOrCreateFor(ProgramPoint dependent, AnchorT anchor) {
+ StateT *state = getOrCreate<StateT>(anchor);
addDependency(state, dependent);
return state;
}
@@ -486,7 +503,7 @@ class DataFlowAnalysis {
};
template <typename AnalysisT, typename... Args>
-AnalysisT *DataFlowSolver::load(Args &&...args) {
+AnalysisT *DataFlowSolver::load(Args &&... args) {
childAnalyses.emplace_back(new AnalysisT(*this, std::forward<Args>(args)...));
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
childAnalyses.back().get()->debugName = llvm::getTypeName<AnalysisT>();
@@ -494,12 +511,12 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
return static_cast<AnalysisT *>(childAnalyses.back().get());
}
-template <typename StateT, typename PointT>
-StateT *DataFlowSolver::getOrCreateState(PointT point) {
+template <typename StateT, typename AnchorT>
+StateT *DataFlowSolver::getOrCreateState(AnchorT anchor) {
std::unique_ptr<AnalysisState> &state =
- analysisStates[{ProgramPoint(point), TypeID::get<StateT>()}];
+ analysisStates[{LatticeAnchor(anchor), TypeID::get<StateT>()}];
if (!state) {
- state = std::unique_ptr<StateT>(new StateT(point));
+ state = std::unique_ptr<StateT>(new StateT(anchor));
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state->debugName = llvm::getTypeName<StateT>();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
@@ -512,20 +529,32 @@ inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
return os;
}
-inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) {
- point.print(os);
+inline raw_ostream &operator<<(raw_ostream &os, LatticeAnchor anchor) {
+ anchor.print(os);
return os;
}
} // end namespace mlir
namespace llvm {
-/// Allow hashing of program points.
+/// Allow hashing of lattice anchors and program points.
+template <>
+struct DenseMapInfo<mlir::LatticeAnchor>
+ : public DenseMapInfo<mlir::LatticeAnchor::ParentTy> {};
+
template <>
struct DenseMapInfo<mlir::ProgramPoint>
: public DenseMapInfo<mlir::ProgramPoint::ParentTy> {};
// Allow llvm::cast style functions.
+template <typename To>
+struct CastInfo<To, mlir::LatticeAnchor>
+ : public CastInfo<To, mlir::LatticeAnchor::PointerUnion> {};
+
+template <typename To>
+struct CastInfo<To, const mlir::LatticeAnchor>
+ : public CastInfo<To, const mlir::LatticeAnchor::PointerUnion> {};
+
template <typename To>
struct CastInfo<To, mlir::ProgramPoint>
: public CastInfo<To, mlir::ProgramPoint::PointerUnion> {};
@@ -534,6 +563,11 @@ template <typename To>
struct CastInfo<To, const mlir::ProgramPoint>
: public CastInfo<To, const mlir::ProgramPoint::PointerUnion> {};
+/// Allow stealing the low bits of a ProgramPoint.
+template <>
+struct PointerLikeTypeTraits<mlir::ProgramPoint>
+ : public PointerLikeTypeTraits<mlir::ProgramPoint::ParentTy> {};
+
} // end namespace llvm
#endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index fab2bd83888da8..d5e525d61760a9 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -46,17 +46,20 @@ void Executable::print(raw_ostream &os) const {
void Executable::onUpdate(DataFlowSolver *solver) const {
AnalysisState::onUpdate(solver);
- if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
- // Re-invoke the analyses on the block itself.
- for (DataFlowAnalysis *analysis : subscribers)
- solver->enqueue({block, analysis});
- // Re-invoke the analyses on all operations in the block.
- for (DataFlowAnalysis *analysis : subscribers)
- for (Operation &op : *block)
- solver->enqueue({&op, analysis});
- } else if (auto *programPoint = llvm::dyn_cast_if_present<GenericProgramPoint *>(point)) {
+ if (ProgramPoint pp = llvm::dyn_cast_if_present<ProgramPoint>(anchor)) {
+ if (Block *block = llvm::dyn_cast_if_present<Block *>(pp)) {
+ // Re-invoke the analyses on the block itself.
+ for (DataFlowAnalysis *analysis : subscribers)
+ solver->enqueue({block, analysis});
+ // Re-invoke the analyses on all operations in the block.
+ for (DataFlowAnalysis *analysis : subscribers)
+ for (Operation &op : *block)
+ solver->enqueue({&op, analysis});
+ }
+ } else if (auto *latticeAnchor =
+ llvm::dyn_cast_if_present<GenericLatticeAnchor *>(anchor)) {
// Re-invoke the analysis on the successor block.
- if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
+ if (auto *edge = dyn_cast<CFGEdge>(latticeAnchor)) {
for (DataFlowAnalysis *analysis : subscribers)
solver->enqueue({edge->getTo(), analysis});
}
@@ -114,7 +117,7 @@ void CFGEdge::print(raw_ostream &os) const {
DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
: DataFlowAnalysis(solver) {
- registerPointKind<CFGEdge>();
+ // markEdgeLive attaches Executable states to CFGEdge lattice anchors.
+ registerAnchorKind<CFGEdge>();
}
LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
@@ -218,7 +221,8 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
+/// Mark block `to` and the CFG edge `from` -> `to` as live, propagating any
+/// resulting state change.
void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
auto *state = getOrCreate<Executable>(to);
propagateIfChanged(state, state->setToLive());
- auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
+ auto *edgeState =
+ getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(from, to));
propagateIfChanged(edgeState, edgeState->setToLive());
}
@@ -235,8 +239,6 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
if (point.is<Block *>())
return success();
auto *op = llvm::dyn_cast_if_present<Operation *>(point);
- if (!op)
- return emitError(point.getLoc(), "unknown program point kind");
// If the parent block is not executable, there is nothing to do.
if (!getOrCreate<Executable>(op->getBlock())->isLive())
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 33c877f78f4bf6..79df0148fb6591 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -45,12 +45,10 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
}
LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
- return processOperation(op);
- else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
- visitBlock(block);
- else
- return failure();
+  // A ProgramPoint is either an operation or a block. Return the operation's
+  // result directly: it must be propagated, and falling through would call
+  // `get<Block *>()` on an operation point.
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
+    return processOperation(op);
+  visitBlock(point.get<Block *>());
return success();
}
@@ -180,7 +178,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
// Skip control edges that aren't executable.
Block *predecessor = *it;
if (!getOrCreateFor<Executable>(
- block, getProgramPoint<CFGEdge>(predecessor, block))
+ block, getLatticeAnchor<CFGEdge>(predecessor, block))
->isLive())
continue;
@@ -248,8 +246,8 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
+/// Return the lattice attached to `anchor`, registering `dependent` as a
+/// dependency of that lattice via addDependency.
const AbstractDenseLattice *
AbstractDenseForwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
- ProgramPoint point) {
- AbstractDenseLattice *state = getLattice(point);
+ LatticeAnchor anchor) {
+ AbstractDenseLattice *state = getLattice(anchor);
addDependency(state, dependent);
return state;
}
@@ -277,12 +275,10 @@ AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
}
LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
- return processOperation(op);
- else if (auto *block = llvm::dyn_cast_if_present<Block *>(point))
- visitBlock(block);
- else
- return failure();
+  // A ProgramPoint is either an operation or a block. Return the operation's
+  // result directly: it must be propagated, and falling through would call
+  // `get<Block *>()` on an operation point.
+  if (auto *op = llvm::dyn_cast_if_present<Operation *>(point))
+    return processOperation(op);
+  visitBlock(point.get<Block *>());
return success();
}
@@ -424,7 +420,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// Meet the state with the state before block's successors.
for (Block *successor : block->getSuccessors()) {
if (!getOrCreateFor<Executable>(block,
- getProgramPoint<CFGEdge>(block, successor))
+ getLatticeAnchor<CFGEdge>(block, successor))
->isLive())
continue;
@@ -474,8 +470,8 @@ void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
+/// Return the lattice attached to `anchor`, registering `dependent` as a
+/// dependency of that lattice via addDependency.
const AbstractDenseLattice *
AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
- ProgramPoint point) {
- AbstractDenseLattice *state = getLattice(point);
+ LatticeAnchor anchor) {
+ AbstractDenseLattice *state = getLattice(anchor);
addDependency(state, dependent);
return state;
}
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 35d38ea02d7162..9a95f172d5df48 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -42,7 +42,7 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
// If the integer range can be narrowed to a constant, update the constant
// value of the SSA value.
std::optional<APInt> constant = getValue().getValue().getConstantValue();
- auto value = point.get<Value>();
+ auto value = anchor.get<Value>();
auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
if (!constant)
return solver->propagateIfChanged(
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index d47d5fec8a9a6a..146ba87c320266 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -34,7 +34,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
AnalysisState::onUpdate(solver);
// Push all users of the value to the queue.
- for (Operation *user : point.get<Value>().getUsers())
+ for (Operation *user : anchor.get<Value>().getUsers())
for (DataFlowAnalysis *analysis : useDefSubscribers)
solver->enqueue({user, analysis});
}
@@ -46,7 +46,7 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
AbstractSparseForwardDataFlowAnalysis::AbstractSparseForwardDataFlowAnalysis(
DataFlowSolver &solver)
: DataFlowAnalysis(solver) {
- registerPointKind<CFGEdge>();
+ // visitBlock queries Executable states attached to CFGEdge lattice anchors.
+ registerAnchorKind<CFGEdge>();
}
LogicalResult
@@ -84,12 +84,10 @@ AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
}
LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint point) {
- if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
- return visitOperation(op);
- else if (Block *block = llvm::dyn_cast_if_present<Block *>(point))
- visitBlock(block);
- else
- return failure();
+  // A ProgramPoint is either an operation or a block. Return the operation's
+  // result directly: it must be propagated, and falling through would call
+  // `get<Block *>()` on an operation point.
+  if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
+    return visitOperation(op);
+  visitBlock(point.get<Block *>());
return success();
}
@@ -217,7 +215,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
// If the edge from the predecessor block to the current block is not live,
// bail out.
auto *edgeExecutable =
- getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
+ getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
edgeExecutable->blockContentSubscribe(this);
if (!edgeExecutable->isLive())
continue;
@@ -324,7 +322,7 @@ void AbstractSparseForwardDataFlowAnalysis::join(
AbstractSparseBackwardDataFlowAnalysis::AbstractSparseBackwardDataFlowAnalysis(
DataFlowSolver &solver, SymbolTableCollection &symbolTable)
: DataFlowAnalysis(solver), symbolTable(symbolTable) {
- registerPointKind<CFGEdge>();
+ // NOTE(review): mirrors the forward analysis; confirm this analysis still
+ // creates CFGEdge lattice anchors.
+ registerAnchorKind<CFGEdge>();
}
LogicalResult
@@ -354,15 +352,11 @@ AbstractSparseBackwardDataFlowAnalysis::initializeRecursively(Operation *op) {
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
+  // Operations are handled by visitOperation; return its result so that
+  // failures are propagated instead of being silently turned into success.
if (Operation *op = llvm::dyn_cast_if_present<Operation *>(point))
- return visitOperation(op);
- else if (llvm::dyn_cast_if_present<Block *>(point))
- // For backward dataflow, we don't have to do any work for the blocks
- // themselves. CFG edges between blocks are processed by the BranchOp
- // logic in `visitOperation`, and entry blocks for functions are tied
- // to the CallOp arguments by visitOperation.
- return success();
- else
- return failure();
+    return visitOperation(op);
+ // For backward dataflow, we don't have to do any work for the blocks
+ // themselves. CFG edges between blocks are processed by the BranchOp
+ // logic in `visitOperation`, and entry blocks for functions are tied
+ // to the CallOp arguments by visitOperation.
return success();
}
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index d0e827aa1c2b64..a65ddc13143bae 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -26,10 +26,10 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// GenericProgramPoint
+// GenericLatticeAnchor
//===----------------------------------------------------------------------===//
-GenericProgramPoint::~GenericProgramPoint() = default;
+// Defaulted destructor, defined out of line in this file.
+GenericLatticeAnchor::~GenericLatticeAnchor() = default;
//===----------------------------------------------------------------------===//
// AnalysisState
@@ -44,7 +44,7 @@ void AnalysisState::addDependency(ProgramPoint dependent,
DATAFLOW_DEBUG({
if (inserted) {
llvm::dbgs() << "Creating dependency between " << debugName << " of "
- << point << "\nand " << debugName << " on " << dependent
+ << anchor << "\nand " << debugName << " on " << dependent
<< "\n";
}
});
@@ -53,7 +53,7 @@ void AnalysisState::addDependency(ProgramPoint dependent,
+/// Debugging helper: print this state to llvm::errs().
void AnalysisState::dump() const { print(llvm::errs()); }
//===----------------------------------------------------------------------===//
-// ProgramPoint
+// LatticeAnchor
//===----------------------------------------------------------------------===//
void ProgramPoint::print(raw_ostream &os) const {
@@ -61,23 +61,36 @@ void ProgramPoint::print(raw_ostream &os) const {
if (isNull()) {
os << "<NULL POINT>";
return;
}
- if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
- return programPoint->print(os);
- if (auto *op = llvm::dyn_cast<Operation *>(*this))
+ if (Operation *op = llvm::dyn_cast<Operation *>(*this)) {
return op->print(os, OpPrintingFlags().skipRegions());
- if (auto value = llvm::dyn_cast<Value>(*this))
- return value.print(os, OpPrintingFlags().skipRegions());
+ }
return get<Block *>()->print(os);
}
-Location ProgramPoint::getLoc() const {
- if (auto *programPoint = llvm::dyn_cast<GenericProgramPoint *>(*this))
- return programPoint->getLoc();
- if (auto *op = llvm::dyn_cast<Operation *>(*this))
- return op->getLoc();
+/// Print a lattice anchor. Generic anchors print themselves, values are
+/// printed without their regions, and program points use ProgramPoint::print.
+void LatticeAnchor::print(raw_ostream &os) const {
+ if (isNull()) {
+ os << "<NULL POINT>";
+ return;
+ }
+ // Named `genericAnchor` rather than `LatticeAnchor`, which would shadow the
+ // enclosing class name inside its own member function.
+ if (auto *genericAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
+ return genericAnchor->print(os);
+ if (auto value = llvm::dyn_cast<Value>(*this)) {
+ return value.print(os, OpPrintingFlags().skipRegions());
+ }
+
+ return get<ProgramPoint>().print(os);
+}
+
+/// Return the location of a lattice anchor. For a program point this is the
+/// operation's location, or the location of the block's parent.
+Location LatticeAnchor::getLoc() const {
+ if (auto *genericAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this))
+ return genericAnchor->getLoc();
if (auto value = llvm::dyn_cast<Value>(*this))
return value.getLoc();
- return get<Block *>()->getParent()->getLoc();
+
+ ProgramPoint pp = get<ProgramPoint>();
+ if (auto *op = llvm::dyn_cast<Operation *>(pp))
+ return op->getLoc();
+ return pp.get<Block *>()->getParent()->getLoc();
}
//===----------------------------------------------------------------------===//
@@ -117,7 +130,7 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
ChangeResult changed) {
if (changed == ChangeResult::Change) {
DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
- << " of " << state->point << "\n"
+ << " of " << state->anchor << "\n"
<< "Value: " << *state << "\n");
state->onUpdate(this);
}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index 90973af9c2cf5d..d02efaaa3fe320 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -40,7 +40,7 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
pred->printAsOperand(os);
os << " = ";
auto *live = solver.lookupState<Executable>(
- solver.getProgramPoint<CFGEdge>(pred, &block));
+ solver.getLatticeAnchor<CFGEdge>(pred, &block));
if (live)
os << *live;
else
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
index 57fe0ca458de21..86eb8651cb90c1 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
@@ -206,7 +206,7 @@ class UnderlyingValueAnalysis
/// At an entry point, the underlying value of a value is itself.
void setToEntryState(UnderlyingValueLattice *lattice) override {
propagateIfChanged(lattice,
- lattice->join(UnderlyingValue{lattice->getPoint()}));
+ lattice->join(UnderlyingValue{lattice->getAnchor()}));
}
/// Look for the most underlying value of a value.
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index b6b33182440cf4..13b5c1a6355767 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -119,11 +119,8 @@ LogicalResult FooAnalysis::visit(ProgramPoint point) {
visitOperation(op);
return success();
}
- if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
- visitBlock(block);
- return success();
- }
- return emitError(point.getLoc(), "unknown point kind");
+ visitBlock(point.get<Block *>());
+ return success();
}
void FooAnalysis::visitBlock(Block *block) {
More information about the Mlir-commits
mailing list