[llvm-branch-commits] [mlir] 1ed1e8c - overhaul state management and allow multi state elements
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jun 29 17:03:51 PDT 2022
Author: Mogball
Date: 2022-06-29T17:03:09-07:00
New Revision: 1ed1e8c7843248091019d5e902b73405567dc464
URL: https://github.com/llvm/llvm-project/commit/1ed1e8c7843248091019d5e902b73405567dc464
DIFF: https://github.com/llvm/llvm-project/commit/1ed1e8c7843248091019d5e902b73405567dc464.diff
LOG: overhaul state management and allow multi state elements
Added:
Modified:
mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/include/mlir/Analysis/DataFlowFramework.h
mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Analysis/DataFlowFramework.cpp
mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
mlir/lib/Transforms/SCCP.cpp
mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
mlir/test/lib/Analysis/TestDataFlowFramework.cpp
mlir/test/lib/Transforms/TestIntRangeInference.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index 8e4baea08ab1..518e5ac23b39 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -29,7 +29,7 @@ namespace dataflow {
class ConstantValue {
public:
/// The pessimistic value state of the constant value is unknown.
- static ConstantValue getPessimisticValueState(Value value) { return {}; }
+ static ConstantValue getPessimisticValue(Value value) { return {}; }
/// Construct a constant value with a known constant.
ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr)
@@ -53,6 +53,14 @@ class ConstantValue {
return lhs == rhs ? lhs : ConstantValue();
}
+ static ConstantValue meet(const ConstantValue &lhs,
+ const ConstantValue &rhs) {
+ if (lhs == rhs) return lhs;
+ if (!lhs.constant) return rhs;
+ if (!rhs.constant) return lhs;
+ return ConstantValue();
+ }
+
/// Print the constant value.
void print(raw_ostream &os) const;
@@ -63,6 +71,13 @@ class ConstantValue {
Dialect *dialect;
};
+class ConstantValueState : public OptimisticSparseState<ConstantValue> {
+public:
+ using OptimisticSparseState::OptimisticSparseState;
+ using ElementT =
+ SparseElement<ConstantValueState, MultiStateElement>;
+};
+
//===----------------------------------------------------------------------===//
// SparseConstantPropagation
//===----------------------------------------------------------------------===//
@@ -72,13 +87,13 @@ class ConstantValue {
/// operands, by speculatively folding operations. When combined with dead-code
/// analysis, this becomes sparse conditional constant propagation (SCCP).
class SparseConstantPropagation
- : public SparseDataFlowAnalysis<Lattice<ConstantValue>> {
+ : public SparseDataFlowAnalysis<ConstantValueState> {
public:
using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
- void visitOperation(Operation *op,
- ArrayRef<const Lattice<ConstantValue> *> operands,
- ArrayRef<Lattice<ConstantValue> *> results) override;
+ void
+ visitOperation(Operation *op, ArrayRef<const ConstantValueState *> operands,
+ ArrayRef<ConstantValueState::ElementT *> results) override;
};
} // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 615c7eece33a..6c8dca49af55 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -24,21 +24,77 @@
namespace mlir {
namespace dataflow {
+//===----------------------------------------------------------------------===//
+// CFGEdge
+//===----------------------------------------------------------------------===//
+
+/// This program point represents a control-flow edge between a block and one
+/// of its successors.
+class CFGEdge
+ : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
+public:
+ using Base::Base;
+
+ /// Get the block from which the edge originates.
+ Block *getFrom() const { return getValue().first; }
+ /// Get the target block.
+ Block *getTo() const { return getValue().second; }
+
+ /// Print the blocks between the control-flow edge.
+ void print(raw_ostream &os) const override;
+ /// Get a fused location of both blocks.
+ Location getLoc() const override;
+};
+
//===----------------------------------------------------------------------===//
// Executable
//===----------------------------------------------------------------------===//
/// This is a simple analysis state that represents whether the associated
/// program point (either a block or a control-flow edge) is live.
-class Executable : public AnalysisState {
+class Executable : public AbstractState {
public:
- using AnalysisState::AnalysisState;
-
- /// The state is initialized by default.
- bool isUninitialized() const override { return false; }
-
- /// The state is always initialized.
- ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+ template <typename ExecutableT>
+ class Element : public SingleStateElement<ExecutableT> {
+ public:
+ using SingleStateElement<ExecutableT>::SingleStateElement;
+
+ /// When the state of the program point is changed to live, re-invoke
+ /// subscribed analyses on the operations in the block and on the block
+ /// itself.
+ void onUpdate() override {
+ if (auto *block = this->point.template dyn_cast<Block *>()) {
+ // Re-invoke the analyses on the block itself.
+ for (DataFlowAnalysis *analysis : subscribers)
+ this->solver.enqueue({block, analysis});
+ // Re-invoke the analyses on all operations in the block.
+ for (DataFlowAnalysis *analysis : subscribers)
+ for (Operation &op : *block)
+ this->solver.enqueue({&op, analysis});
+ } else if (auto *programPoint =
+ this->point.template dyn_cast<GenericProgramPoint *>()) {
+ // Re-invoke the analysis on the successor block.
+ if (auto *edge = dyn_cast<CFGEdge>(programPoint))
+ for (DataFlowAnalysis *analysis : subscribers)
+ this->solver.enqueue({edge->getTo(), analysis});
+ }
+ }
+
+ /// Subscribe an analysis to changes to the liveness.
+ void blockContentSubscribe(DataFlowAnalysis *analysis) {
+ subscribers.insert(analysis);
+ }
+
+ private:
+ /// A set of analyses that should be updated when this state changes.
+ SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
+ SmallPtrSet<DataFlowAnalysis *, 4>>
+ subscribers;
+ };
+ using ElementT = Element<Executable>;
+
+ /// Optimistically assume the program point is dead.
+ explicit Executable(ProgramPoint point) : live(false) {}
/// Set the state of the program point to live.
ChangeResult setToLive();
@@ -46,28 +102,12 @@ class Executable : public AnalysisState {
/// Get whether the program point is live.
bool isLive() const { return live; }
- /// Print the liveness.
+ /// Print the liveness.
void print(raw_ostream &os) const override;
- /// When the state of the program point is changed to live, re-invoke
- /// subscribed analyses on the operations in the block and on the block
- /// itself.
- void onUpdate(DataFlowSolver *solver) const override;
-
- /// Subscribe an analysis to changes to the liveness.
- void blockContentSubscribe(DataFlowAnalysis *analysis) {
- subscribers.insert(analysis);
- }
-
private:
- /// Whether the program point is live. Optimistically assume that the program
- /// point is dead.
- bool live = false;
-
- /// A set of analyses that should be updated when this state changes.
- SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
- SmallPtrSet<DataFlowAnalysis *, 4>>
- subscribers;
+ /// Whether the program point is live.
+ bool live;
};
//===----------------------------------------------------------------------===//
@@ -90,15 +130,11 @@ class Executable : public AnalysisState {
///
/// The state can indicate that it is underdefined, meaning that not all live
/// control-flow predecessors can be known.
-class PredecessorState : public AnalysisState {
+class PredecessorState : public AbstractState {
public:
- using AnalysisState::AnalysisState;
+ using ElementT = SingleStateElement<PredecessorState>;
- /// The state is initialized by default.
- bool isUninitialized() const override { return false; }
-
- /// The state is always initialized.
- ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+ explicit PredecessorState(ProgramPoint point) {}
/// Print the known predecessors.
void print(raw_ostream &os) const override;
@@ -142,28 +178,6 @@ class PredecessorState : public AnalysisState {
DenseMap<Operation *, ValueRange> successorInputs;
};
-//===----------------------------------------------------------------------===//
-// CFGEdge
-//===----------------------------------------------------------------------===//
-
-/// This program point represents a control-flow edge between a block and one
-/// of its successors.
-class CFGEdge
- : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
-public:
- using Base::Base;
-
- /// Get the block from which the edge originates.
- Block *getFrom() const { return getValue().first; }
- /// Get the target block.
- Block *getTo() const { return getValue().second; }
-
- /// Print the blocks between the control-flow edge.
- void print(raw_ostream &os) const override;
- /// Get a fused location of both blocks.
- Location getLoc() const override;
-};
-
//===----------------------------------------------------------------------===//
// DeadCodeAnalysis
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 01b77d958519..96697307a49d 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -16,33 +16,26 @@
#define MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
namespace mlir {
namespace dataflow {
//===----------------------------------------------------------------------===//
-// AbstractDenseLattice
+// AbstractDenseState
//===----------------------------------------------------------------------===//
/// This class represents a dense lattice. A dense lattice is attached to
/// operations to represent the program state after their execution or to blocks
/// to represent the program state at the beginning of the block. A dense
/// lattice is propagated through the IR by dense data-flow analysis.
-class AbstractDenseLattice : public AnalysisState {
-public:
- /// A dense lattice can only be created for operations and blocks.
- using AnalysisState::AnalysisState;
-
- /// Join the lattice across control-flow or callgraph edges.
- virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0;
+using AbstractDenseState = AbstractSparseState;
- /// Reset the dense lattice to a pessimistic value. This occurs when the
- /// analysis cannot reason about the data-flow.
- virtual ChangeResult reset() = 0;
+class AbstractDenseElement : public AbstractElement {
+public:
+ using AbstractElement::AbstractElement;
- /// Returns true if the lattice state has reached a pessimistic fixpoint. That
- /// is, no further modifications to the lattice can occur.
- virtual bool isAtFixpoint() const = 0;
+ virtual const AbstractDenseState *get() const override = 0;
};
//===----------------------------------------------------------------------===//
@@ -78,26 +71,29 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
/// Propagate the dense lattice before the execution of an operation to the
/// lattice after its execution.
virtual void visitOperationImpl(Operation *op,
- const AbstractDenseLattice &before,
- AbstractDenseLattice *after) = 0;
+ const AbstractDenseState &before,
+ AbstractDenseElement *after) = 0;
- /// Get the dense lattice after the execution of the given program point.
- virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
+ /// Get the dense element after the execution of the given program point.
+ virtual AbstractDenseElement *getLattice(ProgramPoint point) = 0;
/// Get the dense lattice after the execution of the given program point and
/// add it as a dependency to a program point.
- const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
- ProgramPoint point);
-
- /// Mark the dense lattice as having reached its pessimistic fixpoint and
- /// propagate an update if it changed.
- void reset(AbstractDenseLattice *lattice) {
- propagateIfChanged(lattice, lattice->reset());
+ const AbstractDenseState *getLatticeFor(ProgramPoint dependent,
+ ProgramPoint point);
+
+ void update(AbstractDenseElement *element,
+ function_ref<ChangeResult(AbstractDenseState *)> updateFn) {
+ element->update(this, [updateFn](AbstractState *state) {
+ return updateFn(static_cast<AbstractDenseState *>(state));
+ });
}
- /// Join a lattice with another and propagate an update if it changed.
- void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) {
- propagateIfChanged(lhs, lhs->join(rhs));
+ void markPessimisticFixpoint(AbstractDenseElement *element) {
+ element->update(this, [](AbstractState *state) {
+ return static_cast<AbstractDenseState *>(state)
+ ->markPessimisticFixpoint();
+ });
}
private:
@@ -116,7 +112,7 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
/// parent operation itself.
void visitRegionBranchOperation(ProgramPoint point,
RegionBranchOpInterface branch,
- AbstractDenseLattice *after);
+ AbstractDenseElement *after);
};
//===----------------------------------------------------------------------===//
@@ -128,29 +124,29 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
/// transfer functions for operations.
///
/// `StateT` is expected to be a subclass of `AbstractDenseLattice`.
-template <typename LatticeT>
+template <typename StateT>
class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis {
public:
using AbstractDenseDataFlowAnalysis::AbstractDenseDataFlowAnalysis;
/// Visit an operation with the dense lattice before its execution. This
/// function is expected to set the dense lattice after its execution.
- virtual void visitOperation(Operation *op, const LatticeT &before,
- LatticeT *after) = 0;
+ virtual void visitOperation(Operation *op, const StateT &before,
+ typename StateT::ElementT *after) = 0;
protected:
/// Get the dense lattice after this program point.
- LatticeT *getLattice(ProgramPoint point) override {
- return getOrCreate<LatticeT>(point);
+ typename StateT::ElementT *getLattice(ProgramPoint point) override {
+ return getOrCreate<StateT>(point);
}
private:
/// Type-erased wrappers that convert the abstract dense lattice to a derived
/// lattice and invoke the virtual hooks operating on the derived lattice.
- void visitOperationImpl(Operation *op, const AbstractDenseLattice &before,
- AbstractDenseLattice *after) override {
- visitOperation(op, static_cast<const LatticeT &>(before),
- static_cast<LatticeT *>(after));
+ void visitOperationImpl(Operation *op, const AbstractDenseState &before,
+ AbstractDenseElement *after) override {
+ visitOperation(op, static_cast<const StateT &>(before),
+ static_cast<typename StateT::ElementT *>(after));
}
};
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 3cd007ab478b..2473bf27f55f 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -16,6 +16,7 @@
#define MLIR_ANALYSIS_DATAFLOW_INTEGERANGEANALYSIS_H
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
namespace mlir {
@@ -27,7 +28,7 @@ class IntegerValueRange {
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
/// range that is used to mark the value as unable to be analyzed further,
/// where `t` is the type of `value`.
- static IntegerValueRange getPessimisticValueState(Value value);
+ static IntegerValueRange getPessimisticValue(Value value);
/// Create an integer value range lattice value.
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
@@ -45,6 +46,10 @@ class IntegerValueRange {
const IntegerValueRange &rhs) {
return lhs.value.rangeUnion(rhs.value);
}
+ static IntegerValueRange meet(const IntegerValueRange &lhs,
+ const IntegerValueRange &rhs) {
+ return lhs.value.intersection(rhs.value);
+ }
/// Print the integer value range.
void print(raw_ostream &os) const { os << value; }
@@ -57,38 +62,49 @@ class IntegerValueRange {
/// This lattice element represents the integer value range of an SSA value.
/// When this lattice is updated, it automatically updates the constant value
/// of the SSA value (if the range can be narrowed to one).
-class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
+class IntegerValueRangeState : public OptimisticSparseState<IntegerValueRange> {
public:
- using Lattice::Lattice;
-
- /// If the range can be narrowed to an integer constant, update the constant
- /// value of the SSA value.
- void onUpdate(DataFlowSolver *solver) const override;
+ using OptimisticSparseState::OptimisticSparseState;
+ using ElementT =
+ SparseElement<IntegerValueRangeState, SingleStateElement>;
};
/// Integer range analysis determines the integer value range of SSA values
/// using operations that define `InferIntRangeInterface` and also sets the
/// range of iteration indices of loops with known bounds.
class IntegerRangeAnalysis
- : public SparseDataFlowAnalysis<IntegerValueRangeLattice> {
+ : public SparseDataFlowAnalysis<IntegerValueRangeState> {
public:
using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
/// Visit an operation. Invoke the transfer function on each operation that
/// implements `InferIntRangeInterface`.
- void visitOperation(Operation *op,
- ArrayRef<const IntegerValueRangeLattice *> operands,
- ArrayRef<IntegerValueRangeLattice *> results) override;
+ void
+ visitOperation(Operation *op,
+ ArrayRef<const IntegerValueRangeState *> operands,
+ ArrayRef<IntegerValueRangeState::ElementT *> results) override;
/// Visit block arguments or operation results of an operation with region
/// control-flow for which values are not defined by region control-flow. This
/// function calls `InferIntRangeInterface` to provide values for block
/// arguments or tries to reduce the range on loop induction variables with
/// known bounds.
- void
- visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
- ArrayRef<IntegerValueRangeLattice *> argLattices,
- unsigned firstIndex) override;
+ void visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<IntegerValueRangeState::ElementT *> argLattices,
+ unsigned firstIndex) override;
+};
+
+class IntegerRangeToConstant : public DataFlowAnalysis {
+public:
+ using DataFlowAnalysis::DataFlowAnalysis;
+
+ LogicalResult initialize(Operation *top) override;
+ LogicalResult visit(ProgramPoint point) override;
+
+ bool staticallyProvides(TypeID stateID, ProgramPoint point) const override {
+ return stateID == TypeID::get<ConstantValueState>() && point.is<Value>();
+ }
};
} // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 27af1bc46b0d..13c5abf4acdb 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -23,67 +23,87 @@ namespace mlir {
namespace dataflow {
//===----------------------------------------------------------------------===//
-// AbstractSparseLattice
+// AbstractSparseState
//===----------------------------------------------------------------------===//
-/// This class represents an abstract lattice. A lattice contains information
-/// about an SSA value and is what's propagated across the IR by sparse
-/// data-flow analysis.
-class AbstractSparseLattice : public AnalysisState {
+class AbstractSparseState : public AbstractState {
public:
- /// Lattices can only be created for values.
- AbstractSparseLattice(Value value) : AnalysisState(value) {}
+ /// Join the information contained in 'rhs' into this state. Returns
+ /// whether the value of the state changed.
+ virtual ChangeResult join(const AbstractSparseState &rhs) = 0;
- /// Join the information contained in 'rhs' into this lattice. Returns
- /// if the value of the lattice changed.
- virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
-
- /// Returns true if the lattice element is at fixpoint and further calls to
- /// `join` will not update the value of the element.
+ /// Returns true if the lattice state is at fixpoint and further calls to
+ /// `join` will not update the value of the state.
virtual bool isAtFixpoint() const = 0;
- /// Mark the lattice element as having reached a pessimistic fixpoint. This
- /// means that the lattice may potentially have conflicting value states, and
- /// only the most conservative value should be relied on.
+ /// Mark the lattice state as having reached a pessimistic fixpoint. This
+ /// means that the lattice may potentially have an overdefined or underdefined
+ /// value state, and only the most conservative value should be relied on.
virtual ChangeResult markPessimisticFixpoint() = 0;
- /// When the lattice gets updated, propagate an update to users of the value
- /// using its use-def chain to subscribed analyses.
- void onUpdate(DataFlowSolver *solver) const override;
+ /// Returns true if the value of this lattice hasn't yet been initialized.
+ virtual bool isUninitialized() const = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseElement
+//===----------------------------------------------------------------------===//
+
+class AbstractSparseElement : public AbstractElement {
+public:
+ /// Sparse elements can only be created on SSA values.
+ explicit AbstractSparseElement(DataFlowSolver &solver, Value value)
+ : AbstractElement(solver, value) {}
- /// Subscribe an analysis to updates of the lattice. When the lattice changes,
- /// subscribed analyses are re-invoked on all users of the value. This is
- /// more efficient than relying on the dependency map.
- void useDefSubscribe(DataFlowAnalysis *analysis) {
+ virtual void useDefSubscribe(DataFlowAnalysis *analysis) = 0;
+
+ virtual const AbstractSparseState *get() const override = 0;
+};
+
+/// This class represents a sparse analysis element. A sparse element is
+/// attached to an SSA value and can track its dependents through the value's
+/// use-def chain. This is useful for improving the performance of sparse
+/// analyses where users are always dependents of SSA value elements.
+template <typename StateT, template <typename, typename> class BaseT>
+class SparseElement : public BaseT<StateT, AbstractSparseElement> {
+public:
+ using BaseT<StateT, AbstractSparseElement>::BaseT;
+
+ /// When the sparse element gets updated, propagate an update to users of the
+ /// value using its use-def chain to subscribed analyses.
+ void onUpdate() override {
+ for (Operation *user : this->point.template get<Value>().getUsers())
+ for (DataFlowAnalysis *analysis : useDefSubscribers)
+ this->solver.enqueue({user, analysis});
+ }
+
+ /// Subscribe an analysis to updates of the sparse element. When the element
+ /// changes, subscribed analyses are re-invoked on all users of the value.
+ /// This is more efficient than relying on the dependency map.
+ void useDefSubscribe(DataFlowAnalysis *analysis) override {
useDefSubscribers.insert(analysis);
}
private:
- /// A set of analyses that should be updated when this lattice changes.
+ /// A set of analyses that should be updated when this element changes.
SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
SmallPtrSet<DataFlowAnalysis *, 4>>
useDefSubscribers;
};
//===----------------------------------------------------------------------===//
-// Lattice
+// OptimisticSparseState
//===----------------------------------------------------------------------===//
-/// This class represents a lattice holding a specific value of type `ValueT`.
-/// Lattice values (`ValueT`) are required to adhere to the following:
-///
-/// * static ValueT join(const ValueT &lhs, const ValueT &rhs);
-/// - This method conservatively joins the information held by `lhs`
-/// and `rhs` into a new value. This method is required to be monotonic.
-/// * bool operator==(const ValueT &rhs) const;
-///
+/// This class represents a sparse state that has an optimistic and known value.
+/// This class should be used when the overdefined/underdefined value state is
+/// not finitely representable.
template <typename ValueT>
-class Lattice : public AbstractSparseLattice {
+class OptimisticSparseState : public AbstractSparseState {
public:
- /// Construct a lattice with a known value.
- explicit Lattice(Value value)
- : AbstractSparseLattice(value),
- knownValue(ValueT::getPessimisticValueState(value)) {}
+ template <typename PointT>
+ explicit OptimisticSparseState(PointT point)
+ : knownValue(ValueT::getPessimisticValue(point)) {}
/// Return the value held by this lattice. This requires that the value is
/// initialized.
@@ -92,16 +112,11 @@ class Lattice : public AbstractSparseLattice {
return *optimisticValue;
}
const ValueT &getValue() const {
- return const_cast<Lattice<ValueT> *>(this)->getValue();
+ return const_cast<OptimisticSparseState<ValueT> *>(this)->getValue();
}
/// Returns true if the value of this lattice hasn't yet been initialized.
bool isUninitialized() const override { return !optimisticValue.hasValue(); }
- /// Force the initialization of the element by setting it to its pessimistic
- /// fixpoint.
- ChangeResult defaultInitialize() override {
- return markPessimisticFixpoint();
- }
/// Returns true if the lattice has reached a fixpoint. A fixpoint is when
/// the information optimistically assumed to be true is the same as the
@@ -110,9 +125,8 @@ class Lattice : public AbstractSparseLattice {
/// Join the information contained in the 'rhs' lattice into this
/// lattice. Returns if the state of the current lattice changed.
- ChangeResult join(const AbstractSparseLattice &rhs) override {
- const Lattice<ValueT> &rhsLattice =
- static_cast<const Lattice<ValueT> &>(rhs);
+ ChangeResult join(const AbstractSparseState &rhs) override {
+ auto &rhsLattice = static_cast<const OptimisticSparseState<ValueT> &>(rhs);
// If we are at a fixpoint, or rhs is uninitialized, there is nothing to do.
if (isAtFixpoint() || rhsLattice.isUninitialized())
@@ -122,6 +136,21 @@ class Lattice : public AbstractSparseLattice {
return join(rhsLattice.getValue());
}
+ ChangeResult meet(const OptimisticSparseState<ValueT> &rhs) {
+ if (isUninitialized())
+ return ChangeResult::NoChange;
+ if (rhs.isUninitialized()) {
+ optimisticValue.reset();
+ return ChangeResult::Change;
+ }
+ ValueT newValue = ValueT::meet(getValue(), rhs.getValue());
+ if (newValue == optimisticValue)
+ return ChangeResult::NoChange;
+
+ optimisticValue = newValue;
+ return ChangeResult::Change;
+ }
+
/// Join the information contained in the 'rhs' value into this
/// lattice. Returns if the state of the current lattice changed.
ChangeResult join(const ValueT &rhs) {
@@ -159,16 +188,14 @@ class Lattice : public AbstractSparseLattice {
return ChangeResult::Change;
}
- /// Print the lattice element.
void print(raw_ostream &os) const override {
- os << "[";
+ os << '[';
knownValue.print(os);
- os << ", ";
- if (optimisticValue)
+ if (optimisticValue) {
+ os << ", ";
optimisticValue->print(os);
- else
- os << "<NULL>";
- os << "]";
+ }
+ os << ']';
}
private:
@@ -206,10 +233,10 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
/// The operation transfer function. Given the operand lattices, this
/// function is expected to set the result lattices.
- virtual void
- visitOperationImpl(Operation *op,
- ArrayRef<const AbstractSparseLattice *> operandLattices,
- ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
+ virtual void visitOperationImpl(
+ Operation *op,
+ ArrayRef<const AbstractSparseState *> operandLattices,
+ ArrayRef<AbstractSparseElement *> resultLattices) = 0;
/// Given an operation with region control-flow, the lattices of the operands,
/// and a region successor, compute the lattice values for block arguments
@@ -217,26 +244,29 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
/// of loops).
virtual void visitNonControlFlowArgumentsImpl(
Operation *op, const RegionSuccessor &successor,
- ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
+ ArrayRef<AbstractSparseElement *> argLattices,
+ unsigned firstIndex) = 0;
/// Get the lattice element of a value.
- virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
+ virtual AbstractSparseElement *getLatticeElement(Value value) = 0;
/// Get a read-only lattice element for a value and add it as a dependency to
/// a program point.
- const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
- Value value);
+ const AbstractSparseState *getLatticeElementFor(ProgramPoint point,
+ Value value);
/// Mark a lattice element as having reached its pessimistic fixpoint and
/// propgate an update if changed.
- void markPessimisticFixpoint(AbstractSparseLattice *lattice);
+ void markPessimisticFixpoint(AbstractSparseElement *element);
/// Mark the given lattice elements as having reached their pessimistic
/// fixpoints and propagate an update if any changed.
- void markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice *> lattices);
+ void markAllPessimisticFixpoint(
+ ArrayRef<AbstractSparseElement *> elements);
/// Join the lattice element and propagate and update if it changed.
- void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+ void join(AbstractSparseElement *lhs,
+ const AbstractSparseState &rhs);
private:
/// Recursively initialize the analysis on nested operations and blocks.
@@ -255,9 +285,10 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
/// operation `branch`, which can either be the entry block of one of the
/// regions or the parent operation itself, and set either the argument or
/// parent result lattices.
- void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
- Optional<unsigned> successorIndex,
- ArrayRef<AbstractSparseLattice *> lattices);
+ void
+ visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
+ Optional<unsigned> successorIndex,
+ ArrayRef<AbstractSparseElement *> elements);
};
//===----------------------------------------------------------------------===//
@@ -267,7 +298,7 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
/// A sparse (forward) data-flow analysis for propagating SSA value lattices
/// across the IR by implementing transfer functions for operations.
///
-/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
+/// `StateT` is expected to be a subclass of `AbstractSparseState`.
template <typename StateT>
class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
public:
@@ -276,8 +307,9 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
/// Visit an operation with the lattices of its operands. This function is
/// expected to set the lattices of the operation's results.
- virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
- ArrayRef<StateT *> results) = 0;
+ virtual void
+ visitOperation(Operation *op, ArrayRef<const StateT *> operands,
+ ArrayRef<typename StateT::ElementT *> results) = 0;
/// Given an operation with possible region control-flow, the lattices of the
/// operands, and a region successor, compute the lattice values for block
@@ -285,18 +317,21 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
/// the bounds of loops). By default, this method marks all such lattice
/// elements as having reached a pessimistic fixpoint. `firstIndex` is the
/// index of the first element of `argLattices` that is set by control-flow.
- virtual void visitNonControlFlowArguments(Operation *op,
- const RegionSuccessor &successor,
- ArrayRef<StateT *> argLattices,
- unsigned firstIndex) {
+ virtual void visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<typename StateT::ElementT *> argLattices, unsigned firstIndex) {
markAllPessimisticFixpoint(argLattices.take_front(firstIndex));
markAllPessimisticFixpoint(argLattices.drop_front(
firstIndex + successor.getSuccessorInputs().size()));
}
protected:
+ bool staticallyProvides(TypeID stateID, ProgramPoint point) const override {
+ return stateID == TypeID::get<StateT>() && point.is<Value>();
+ }
+
/// Get the lattice element for a value.
- StateT *getLatticeElement(Value value) override {
+ typename StateT::ElementT *getLatticeElement(Value value) override {
return getOrCreate<StateT>(value);
}
@@ -309,32 +344,37 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
/// Mark the lattice elements of a range of values as having reached their
/// pessimistic fixpoint.
- void markAllPessimisticFixpoint(ArrayRef<StateT *> lattices) {
+ void
+ markAllPessimisticFixpoint(ArrayRef<typename StateT::ElementT *> elements) {
AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
- {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
- lattices.size()});
+ {reinterpret_cast<AbstractSparseElement *const *>(
+ elements.begin()),
+ elements.size()});
}
private:
/// Type-erased wrappers that convert the abstract lattice operands to derived
/// lattices and invoke the virtual hooks operating on the derived lattices.
void visitOperationImpl(
- Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
- ArrayRef<AbstractSparseLattice *> resultLattices) override {
+ Operation *op,
+ ArrayRef<const AbstractSparseState *> operandLattices,
+ ArrayRef<AbstractSparseElement *> resultLattices) override {
visitOperation(
op,
{reinterpret_cast<const StateT *const *>(operandLattices.begin()),
operandLattices.size()},
- {reinterpret_cast<StateT *const *>(resultLattices.begin()),
+ {reinterpret_cast<typename StateT::ElementT *const *>(
+ resultLattices.begin()),
resultLattices.size()});
}
void visitNonControlFlowArgumentsImpl(
Operation *op, const RegionSuccessor &successor,
- ArrayRef<AbstractSparseLattice *> argLattices,
+ ArrayRef<AbstractSparseElement *> argLattices,
unsigned firstIndex) override {
visitNonControlFlowArguments(
op, successor,
- {reinterpret_cast<StateT *const *>(argLattices.begin()),
+ {reinterpret_cast<typename StateT::ElementT *const *>(
+ argLattices.begin()),
argLattices.size()},
firstIndex);
}
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 2992e05f14dd..ab898a70a8bd 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -45,9 +45,6 @@ inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
return lhs == ChangeResult::NoChange ? lhs : rhs;
}
-/// Forward declare the analysis state class.
-class AnalysisState;
-
//===----------------------------------------------------------------------===//
// GenericProgramPoint
//===----------------------------------------------------------------------===//
@@ -178,6 +175,8 @@ class DataFlowAnalysis;
// DataFlowSolver
//===----------------------------------------------------------------------===//
+class AbstractElement;
+
/// The general data-flow analysis solver. This class is responsible for
/// orchestrating child data-flow analyses, running the fixed-point iteration
/// algorithm, managing analysis state and program point memory, and tracking
@@ -202,16 +201,19 @@ class DataFlowSolver {
/// operation and run the analysis until fixpoint.
LogicalResult initializeAndRun(Operation *top);
- /// Lookup an analysis state for the given program point. Returns null if one
- /// does not exist.
template <typename StateT, typename PointT>
- const StateT *lookupState(PointT point) const {
- auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
- if (it == analysisStates.end())
+ const StateT *lookup(PointT point) const {
+ using ElementT = typename StateT::ElementT;
+ auto it = elements.find({TypeID::get<StateT>(), ProgramPoint(point)});
+ if (it == elements.end())
return nullptr;
- return static_cast<const StateT *>(it->second.get());
+ return static_cast<const ElementT &>(*it->second).get();
}
+ template <typename StateT, typename PointT>
+ typename StateT::ElementT *getOrCreate(PointT point);
+
+public:
/// Get a uniqued program point instance. If one is not present, it is
/// created with the provided arguments.
template <typename PointT, typename... Args>
@@ -226,20 +228,9 @@ class DataFlowSolver {
/// Push a work item onto the worklist.
void enqueue(WorkItem item) { worklist.push(std::move(item)); }
- /// Get the state associated with the given program point. If it does not
- /// exist, create an uninitialized state.
- template <typename StateT, typename PointT>
- StateT *getOrCreateState(PointT point);
-
- /// Propagate an update to an analysis state if it changed by pushing
- /// dependent work items to the back of the queue.
- void propagateIfChanged(AnalysisState *state, ChangeResult changed);
-
- /// Add a dependency to an analysis state on a child analysis and program
- /// point. If the state is updated, the child analysis must be invoked on the
- /// given program point again.
- void addDependency(AnalysisState *state, DataFlowAnalysis *analysis,
- ProgramPoint point);
+ void getStaticProvidersFor(
+ TypeID stateID, ProgramPoint point,
+ SmallVectorImpl<DataFlowAnalysis *> &staticProviders) const;
private:
/// The solver's work queue. Work items can be inserted to the front of the
@@ -254,78 +245,129 @@ class DataFlowSolver {
/// points.
StorageUniquer uniquer;
- /// A type-erased map of program points to associated analysis states for
- /// first-class program points.
- DenseMap<std::pair<ProgramPoint, TypeID>, std::unique_ptr<AnalysisState>>
- analysisStates;
+ /// A type-erased map of program points to associated analysis states.
+ DenseMap<std::pair<TypeID, ProgramPoint>,
+ std::unique_ptr<AbstractElement>>
+ elements;
/// Allow the base child analysis class to access the internals of the solver.
friend class DataFlowAnalysis;
};
//===----------------------------------------------------------------------===//
-// AnalysisState
+// AbstractElement
//===----------------------------------------------------------------------===//
-/// Base class for generic analysis states. Analysis states contain data-flow
-/// information that are attached to program points and which evolve as the
-/// analysis iterates.
-///
-/// This class places no restrictions on the semantics of analysis states beyond
-/// these requirements.
-///
-/// 1. Querying the state of a program point prior to visiting that point
-/// results in uninitialized state. Analyses must be aware of unintialized
-/// states.
-/// 2. Analysis states can reach fixpoints, where subsequent updates will never
-/// trigger a change in the state.
-/// 3. Analysis states that are uninitialized can be forcefully initialized to a
-/// default value.
-class AnalysisState {
+class AbstractState {
public:
- virtual ~AnalysisState();
+ virtual ~AbstractState();
- /// Create the analysis state at the given program point.
- AnalysisState(ProgramPoint point) : point(point) {}
+ virtual void print(raw_ostream &os) const = 0;
+};
- /// Returns true if the analysis state is uninitialized.
- virtual bool isUninitialized() const = 0;
+/// Subclasses are required to implement `get` and `update`.
+class AbstractElement {
+public:
+ virtual ~AbstractElement();
- /// Force an uninitialized analysis state to initialize itself with a default
- /// value.
- virtual ChangeResult defaultInitialize() = 0;
+ explicit AbstractElement(DataFlowSolver &solver, ProgramPoint point)
+ : solver(solver), point(point) {}
- /// Print the contents of the analysis state.
- virtual void print(raw_ostream &os) const = 0;
+ void addDependency(DataFlowAnalysis *analysis, ProgramPoint point);
+
+ virtual const AbstractState *get() const = 0;
+ virtual void update(DataFlowAnalysis *provider,
+ function_ref<ChangeResult(AbstractState *)> updateFn) = 0;
protected:
- /// This function is called by the solver when the analysis state is updated
- /// to optionally enqueue more work items. For example, if a state tracks
- /// dependents through the IR (e.g. use-def chains), this function can be
- /// implemented to push those dependents on the worklist.
- virtual void onUpdate(DataFlowSolver *solver) const {}
-
- /// The dependency relations originating from this analysis state. An entry
- /// `state -> (analysis, point)` is created when `analysis` queries `state`
- /// when updating `point`.
- ///
- /// When this state is updated, all dependent child analysis invocations are
- /// pushed to the back of the queue. Use a `SetVector` to keep the analysis
- /// deterministic.
- ///
- /// Store the dependents on the analysis state for efficiency.
- SetVector<DataFlowSolver::WorkItem> dependents;
+ void propagateUpdate();
+
+ virtual void onUpdate() {}
- /// The program point to which the state belongs.
+ DataFlowSolver &solver;
ProgramPoint point;
+private:
+ SetVector<DataFlowSolver::WorkItem, SmallVector<DataFlowSolver::WorkItem>,
+ llvm::SmallDenseSet<DataFlowSolver::WorkItem>>
+ dependents;
+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- /// When compiling with debugging, keep a name for the analysis state.
+ /// When compiling with debugging, keep a name for the element.
StringRef debugName;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
- /// Allow the framework to access the dependents.
- friend class DataFlowSolver;
+ friend class ::mlir::DataFlowSolver;
+};
+
+template <typename StateT, typename BaseT = AbstractElement>
+class SingleStateElement : public BaseT {
+public:
+ template <typename PointT>
+ explicit SingleStateElement(DataFlowSolver &solver, PointT point)
+ : BaseT(solver, point), state(point) {}
+
+ const StateT *get() const override { return &state; }
+
+ void update(DataFlowAnalysis *provider,
+ function_ref<ChangeResult(AbstractState *)> updateFn) override {
+ if (updateFn(&state) == ChangeResult::Change)
+ BaseT::propagateUpdate();
+ }
+ void update(DataFlowAnalysis *provider,
+ function_ref<ChangeResult(StateT *)> updateFn) {
+ return update(provider, function_ref<ChangeResult(AbstractState *)>(
+ [updateFn](AbstractState *state) {
+ return updateFn(static_cast<StateT *>(state));
+ }));
+ }
+
+private:
+ StateT state;
+};
+
+/// StateT is required to implement `join` and `meet`.
+template <typename StateT, typename BaseT = AbstractElement>
+class MultiStateElement : public BaseT {
+public:
+ template <typename PointT>
+ explicit MultiStateElement(DataFlowSolver &solver, PointT point)
+ : BaseT(solver, point), state(point) {
+ SmallVector<DataFlowAnalysis *, 2> staticProviders;
+ solver.getStaticProvidersFor(TypeID::get<StateT>(), point, staticProviders);
+ for (DataFlowAnalysis *staticProvider : staticProviders)
+ states.try_emplace(staticProvider, StateT(point));
+ }
+
+ const StateT *get() const override { return &state; }
+
+ void update(DataFlowAnalysis *provider,
+ function_ref<ChangeResult(AbstractState *)> updateFn) override {
+ auto it = states.find(provider);
+ if (it == states.end()) {
+ if (updateFn(&state) == ChangeResult::Change)
+ BaseT::propagateUpdate();
+ return;
+ }
+ if (updateFn(&it->second) == ChangeResult::NoChange)
+ return;
+ StateT newState(it->second);
+ for (auto &entry : states)
+ (void)newState.meet(entry.second);
+ if (state.join(newState) == ChangeResult::Change)
+ BaseT::propagateUpdate();
+ }
+ void update(DataFlowAnalysis *provider,
+ function_ref<ChangeResult(StateT *)> updateFn) {
+ return update(provider, function_ref<ChangeResult(AbstractState *)>(
+ [updateFn](AbstractState *state) {
+ return updateFn(static_cast<StateT *>(state));
+ }));
+ }
+
+private:
+ StateT state;
+ llvm::SmallDenseMap<DataFlowAnalysis *, StateT, 2> states;
};
//===----------------------------------------------------------------------===//
@@ -385,12 +427,13 @@ class DataFlowAnalysis {
virtual LogicalResult visit(ProgramPoint point) = 0;
protected:
- /// Create a dependency between the given analysis state and program point
- /// on this analysis.
- void addDependency(AnalysisState *state, ProgramPoint point);
-
- /// Propagate an update to a state if it changed.
- void propagateIfChanged(AnalysisState *state, ChangeResult changed);
+ /// Returns true if this analysis *statically* provides values for the given
+ /// state kind for the given program point. This means the analysis will
+ /// always provide values for this state regardless of the state of the
+ /// analysis.
+ virtual bool staticallyProvides(TypeID stateID, ProgramPoint point) const {
+ return false;
+ }
/// Register a custom program point class.
template <typename PointT>
@@ -404,12 +447,18 @@ class DataFlowAnalysis {
return solver.getProgramPoint<PointT>(std::forward<Args>(args)...);
}
+ template <typename StateT, typename PointT>
+ void update(PointT point, function_ref<ChangeResult(StateT *)> updateFn) {
+ auto *element = getOrCreate<StateT>(point);
+ element->update(this, updateFn);
+ }
+
/// Get the analysis state assiocated with the program point. The returned
/// state is expected to be "write-only", and any updates need to be
/// propagated by `propagateIfChanged`.
template <typename StateT, typename PointT>
- StateT *getOrCreate(PointT point) {
- return solver.getOrCreateState<StateT>(point);
+ typename StateT::ElementT *getOrCreate(PointT point) {
+ return solver.getOrCreate<StateT>(point);
}
/// Get a read-only analysis state for the given point and create a dependency
@@ -417,14 +466,15 @@ class DataFlowAnalysis {
/// re-invoked on the dependent.
template <typename StateT, typename PointT>
const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) {
- StateT *state = getOrCreate<StateT>(point);
- addDependency(state, dependent);
- return state;
+ auto *element = getOrCreate<StateT>(point);
+ element->addDependency(this, dependent);
+ return element->get();
}
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// When compiling with debugging, keep a name for the analyis.
StringRef debugName;
+ friend class AbstractElement;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
private:
@@ -445,19 +495,21 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
}
template <typename StateT, typename PointT>
-StateT *DataFlowSolver::getOrCreateState(PointT point) {
- std::unique_ptr<AnalysisState> &state =
- analysisStates[{ProgramPoint(point), TypeID::get<StateT>()}];
- if (!state) {
- state = std::unique_ptr<StateT>(new StateT(point));
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- state->debugName = llvm::getTypeName<StateT>();
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+typename StateT::ElementT *DataFlowSolver::getOrCreate(PointT point) {
+ using ElementT = typename StateT::ElementT;
+ static_assert(std::is_base_of<AbstractElement, ElementT>::value,
+ "expected an abstract element");
+ std::unique_ptr<AbstractElement> &element =
+ elements[{TypeID::get<StateT>(), ProgramPoint(point)}];
+ if (!element) {
+ element = std::unique_ptr<ElementT>(new ElementT(*this, point));
+ element->debugName = llvm::getTypeName<StateT>();
}
- return static_cast<StateT *>(state.get());
+ return static_cast<ElementT *>(element.get());
}
-inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
+inline raw_ostream &operator<<(raw_ostream &os,
+ const AbstractState &state) {
state.print(os);
return os;
}
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 386237e47b5e..39affbdcfb34 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -30,8 +30,8 @@ void ConstantValue::print(raw_ostream &os) const {
//===----------------------------------------------------------------------===//
void SparseConstantPropagation::visitOperation(
- Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
- ArrayRef<Lattice<ConstantValue> *> results) {
+ Operation *op, ArrayRef<const ConstantValueState *> operands,
+ ArrayRef<ConstantValueState::ElementT *> results) {
LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
// Don't try to simulate the results of a region operation as we can't
@@ -39,12 +39,12 @@ void SparseConstantPropagation::visitOperation(
// folds as the desire here is for simulated execution, and not general
// folding.
if (op->getNumRegions())
- return;
+ return markAllPessimisticFixpoint(results);
SmallVector<Attribute, 8> constantOperands;
constantOperands.reserve(op->getNumOperands());
- for (auto *operandLattice : operands)
- constantOperands.push_back(operandLattice->getValue().getConstantValue());
+ for (auto *operandState : operands)
+ constantOperands.push_back(operandState->getValue().getConstantValue());
// Save the original operands and attributes just in case the operation
// folds in-place. The constant passed in may not correspond to the real
@@ -56,10 +56,8 @@ void SparseConstantPropagation::visitOperation(
// fails or was an in-place fold, mark the results as overdefined.
SmallVector<OpFoldResult, 8> foldResults;
foldResults.reserve(op->getNumResults());
- if (failed(op->fold(constantOperands, foldResults))) {
- markAllPessimisticFixpoint(results);
- return;
- }
+ if (failed(op->fold(constantOperands, foldResults)))
+ return markAllPessimisticFixpoint(results);
// If the folding was in-place, mark the results as overdefined and reset
// the operation. We don't allow in-place folds as the desire here is for
@@ -67,25 +65,26 @@ void SparseConstantPropagation::visitOperation(
if (foldResults.empty()) {
op->setOperands(originalOperands);
op->setAttrs(originalAttrs);
- return;
+ return markAllPessimisticFixpoint(results);
}
// Merge the fold results into the lattice for this operation.
assert(foldResults.size() == op->getNumResults() && "invalid result size");
for (const auto it : llvm::zip(results, foldResults)) {
- Lattice<ConstantValue> *lattice = std::get<0>(it);
+ ConstantValueState::ElementT *element = std::get<0>(it);
// Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = std::get<1>(it);
if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
- propagateIfChanged(lattice,
- lattice->join(ConstantValue(attr, op->getDialect())));
+ element->update(this, [attr, op](ConstantValueState *state) {
+ return state->join(ConstantValue(attr, op->getDialect()));
+ });
} else {
LLVM_DEBUG(llvm::dbgs()
<< "Folded to value: " << foldResult.get<Value>() << "\n");
AbstractSparseDataFlowAnalysis::join(
- lattice, *getLatticeElement(foldResult.get<Value>()));
+ element, *getLatticeElement(foldResult.get<Value>())->get());
}
}
}
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 5e5215144c39..6ff6fb4c3735 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -27,24 +27,6 @@ void Executable::print(raw_ostream &os) const {
os << (live ? "live" : "dead");
}
-void Executable::onUpdate(DataFlowSolver *solver) const {
- if (auto *block = point.dyn_cast<Block *>()) {
- // Re-invoke the analyses on the block itself.
- for (DataFlowAnalysis *analysis : subscribers)
- solver->enqueue({block, analysis});
- // Re-invoke the analyses on all operations in the block.
- for (DataFlowAnalysis *analysis : subscribers)
- for (Operation &op : *block)
- solver->enqueue({&op, analysis});
- } else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
- // Re-invoke the analysis on the successor block.
- if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
- for (DataFlowAnalysis *analysis : subscribers)
- solver->enqueue({edge->getTo(), analysis});
- }
- }
-}
-
//===----------------------------------------------------------------------===//
// PredecessorState
//===----------------------------------------------------------------------===//
@@ -104,8 +86,8 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
for (Region &region : top->getRegions()) {
if (region.empty())
continue;
- auto *state = getOrCreate<Executable>(&region.front());
- propagateIfChanged(state, state->setToLive());
+ update<Executable>(&region.front(),
+ [](Executable *state) { return state->setToLive(); });
}
// Mark as overdefined the predecessors of symbol callables with potentially
@@ -132,8 +114,9 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
// Public symbol callables or those for which we can't see all uses have
// potentially unknown callsites.
if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
- auto *state = getOrCreate<PredecessorState>(callable);
- propagateIfChanged(state, state->setHasUnknownPredecessors());
+ update<PredecessorState>(callable, [](PredecessorState *state) {
+ return state->setHasUnknownPredecessors();
+ });
}
foundSymbolCallable = true;
}
@@ -149,8 +132,9 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
// If we couldn't gather the symbol uses, conservatively assume that
// we can't track information for any nested symbols.
return top->walk([&](CallableOpInterface callable) {
- auto *state = getOrCreate<PredecessorState>(callable);
- propagateIfChanged(state, state->setHasUnknownPredecessors());
+ update<PredecessorState>(callable, [](PredecessorState *state) {
+ return state->setHasUnknownPredecessors();
+ });
});
}
@@ -160,12 +144,12 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
// If a callable symbol has a non-call use, then we can't be guaranteed to
// know all callsites.
Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
- auto *state = getOrCreate<PredecessorState>(symbol);
- propagateIfChanged(state, state->setHasUnknownPredecessors());
+ update<PredecessorState>(symbol, [](PredecessorState *state) {
+ return state->setHasUnknownPredecessors();
+ });
}
};
- SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
- walkFn);
+ SymbolTable::walkSymbolTables(top, !top->getBlock(), walkFn);
}
/// Returns true if the operation terminates a block. It is insufficient to
@@ -198,18 +182,17 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
}
void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
- auto *state = getOrCreate<Executable>(to);
- propagateIfChanged(state, state->setToLive());
- auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
- propagateIfChanged(edgeState, edgeState->setToLive());
+ update<Executable>(to, [](Executable *state) { return state->setToLive(); });
+ update<Executable>(getProgramPoint<CFGEdge>(from, to),
+ [](Executable *state) { return state->setToLive(); });
}
void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
for (Region &region : op->getRegions()) {
if (region.empty())
continue;
- auto *state = getOrCreate<Executable>(&region.front());
- propagateIfChanged(state, state->setToLive());
+ update<Executable>(&region.front(),
+ [](Executable *state) { return state->setToLive(); });
}
}
@@ -221,7 +204,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
return emitError(point.getLoc(), "unknown program point kind");
// If the parent block is not executable, there is nothing to do.
- if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ if (!getOrCreate<Executable>(op->getBlock())->get()->isLive())
return success();
// We have a live call op. Add this as a live predecessor of the callee.
@@ -296,25 +279,27 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
!isExternalCallable(callableOp)) {
// Add the live callsite.
- auto *callsites = getOrCreate<PredecessorState>(callableOp);
- propagateIfChanged(callsites, callsites->join(call));
+ update<PredecessorState>(callableOp, [call](PredecessorState *state) {
+ return state->join(call);
+ });
} else {
// Mark this call op's predecessors as overdefined.
- auto *predecessors = getOrCreate<PredecessorState>(call);
- propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
+ update<PredecessorState>(call, [](PredecessorState *state) {
+ return state->setHasUnknownPredecessors();
+ });
}
}
/// Get the constant values of the operands of an operation. If any of the
/// constant value lattices are uninitialized, return none to indicate the
/// analysis should bail out.
-static Optional<SmallVector<Attribute>> getOperandValuesImpl(
- Operation *op,
- function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
+static Optional<SmallVector<Attribute>>
+getOperandValuesImpl(Operation *op,
+ function_ref<const ConstantValueState *(Value)> getState) {
SmallVector<Attribute> operands;
operands.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
- const Lattice<ConstantValue> *cv = getLattice(operand);
+ const ConstantValueState *cv = getState(operand);
// If any of the operands' values are uninitialized, bail out.
if (cv->isUninitialized())
return {};
@@ -325,11 +310,12 @@ static Optional<SmallVector<Attribute>> getOperandValuesImpl(
Optional<SmallVector<Attribute>>
DeadCodeAnalysis::getOperandValues(Operation *op) {
- return getOperandValuesImpl(op, [&](Value value) {
- auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
- lattice->useDefSubscribe(this);
- return lattice;
- });
+ return getOperandValuesImpl(
+ op, [&](Value value) -> const ConstantValueState * {
+ auto *element = getOrCreate<ConstantValueState>(value);
+ element->useDefSubscribe(this);
+ return element->get();
+ });
}
void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
@@ -362,13 +348,13 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
? &successor.getSuccessor()->front()
: ProgramPoint(branch);
// Mark the entry block as executable.
- auto *state = getOrCreate<Executable>(point);
- propagateIfChanged(state, state->setToLive());
+ update<Executable>(point,
+ [](Executable *state) { return state->setToLive(); });
// Add the parent op as a predecessor.
- auto *predecessors = getOrCreate<PredecessorState>(point);
- propagateIfChanged(
- predecessors,
- predecessors->join(branch, successor.getSuccessorInputs()));
+ update<PredecessorState>(
+ point, [branch, &successor](PredecessorState *state) {
+ return state->join(branch, successor.getSuccessorInputs());
+ });
}
}
@@ -385,17 +371,20 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
// Mark successor region entry blocks as executable and add this op to the
// list of predecessors.
for (const RegionSuccessor &successor : successors) {
- PredecessorState *predecessors;
if (Region *region = successor.getSuccessor()) {
- auto *state = getOrCreate<Executable>(&region->front());
- propagateIfChanged(state, state->setToLive());
- predecessors = getOrCreate<PredecessorState>(&region->front());
+ update<Executable>(&region->front(),
+ [](Executable *state) { return state->setToLive(); });
+ update<PredecessorState>(
+ &region->front(), [op, &successor](PredecessorState *state) {
+ return state->join(op, successor.getSuccessorInputs());
+ });
} else {
// Add this terminator as a predecessor to the parent op.
- predecessors = getOrCreate<PredecessorState>(branch);
+ update<PredecessorState>(
+ branch, [op, &successor](PredecessorState *state) {
+ return state->join(op, successor.getSuccessorInputs());
+ });
}
- propagateIfChanged(predecessors,
- predecessors->join(op, successor.getSuccessorInputs()));
}
}
@@ -412,12 +401,14 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
assert(isa<CallOpInterface>(predecessor));
auto *predecessors = getOrCreate<PredecessorState>(predecessor);
if (canResolve) {
- propagateIfChanged(predecessors, predecessors->join(op));
+ predecessors->update(
+ this, [op](PredecessorState *state) { return state->join(op); });
} else {
// If the terminator is not a return-like, then conservatively assume we
// can't resolve the predecessor.
- propagateIfChanged(predecessors,
- predecessors->setHasUnknownPredecessors());
+ predecessors->update(this, [](PredecessorState *state) {
+ return state->setHasUnknownPredecessors();
+ });
}
}
}
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 95c4aadc7592..957880f6c999 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -45,8 +45,8 @@ void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
return;
// Get the dense lattice to update.
- AbstractDenseLattice *after = getLattice(op);
- if (after->isAtFixpoint())
+ AbstractDenseElement *after = getLattice(op);
+ if (after->get()->isAtFixpoint())
return;
// If this op implements region control-flow, then control-flow dictates its
@@ -61,14 +61,17 @@ void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
if (!predecessors->allPredecessorsKnown())
- return reset(after);
- for (Operation *predecessor : predecessors->getKnownPredecessors())
- join(after, *getLatticeFor(op, predecessor));
- return;
+ return markPessimisticFixpoint(after);
+ return update(after, [this, predecessors, op](AbstractDenseState *state) {
+ ChangeResult result = ChangeResult::NoChange;
+ for (Operation *predecessor : predecessors->getKnownPredecessors())
+ result |= state->join(*getLatticeFor(op, predecessor));
+ return result;
+ });
}
// Get the dense state before the execution of the op.
- const AbstractDenseLattice *before;
+ const AbstractDenseState *before;
if (Operation *prev = op->getPrevNode())
before = getLatticeFor(op, prev);
else
@@ -87,8 +90,8 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
return;
// Get the dense lattice to update.
- AbstractDenseLattice *after = getLattice(block);
- if (after->isAtFixpoint())
+ AbstractDenseElement *after = getLattice(block);
+ if (after->get()->isAtFixpoint())
return;
// The dense lattices of entry blocks are set by region control-flow or the
@@ -101,15 +104,17 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown())
- return reset(after);
- for (Operation *callsite : callsites->getKnownPredecessors()) {
- // Get the dense lattice before the callsite.
- if (Operation *prev = callsite->getPrevNode())
- join(after, *getLatticeFor(block, prev));
- else
- join(after, *getLatticeFor(block, callsite->getBlock()));
- }
- return;
+ return markPessimisticFixpoint(after);
+ return update(after, [this, callsites, block](AbstractDenseState *state) {
+ ChangeResult result = ChangeResult::NoChange;
+ for (Operation *callsite : callsites->getKnownPredecessors()) {
+ if (Operation *prev = callsite->getPrevNode())
+ result |= state->join(*getLatticeFor(block, prev));
+ else
+ result |= state->join(*getLatticeFor(block, callsite->getBlock()));
+ }
+ return result;
+ });
}
// Check if we can reason about the control-flow.
@@ -117,53 +122,62 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
return visitRegionBranchOperation(block, branch, after);
// Otherwise, we can't reason about the data-flow.
- return reset(after);
+ return markPessimisticFixpoint(after);
}
// Join the state with the state after the block's predecessors.
- for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
- it != e; ++it) {
- // Skip control edges that aren't executable.
- Block *predecessor = *it;
- if (!getOrCreateFor<Executable>(
- block, getProgramPoint<CFGEdge>(predecessor, block))
- ->isLive())
- continue;
-
- // Merge in the state from the predecessor's terminator.
- join(after, *getLatticeFor(block, predecessor->getTerminator()));
- }
+ update(after, [this, block](AbstractDenseState *state) {
+ ChangeResult result = ChangeResult::NoChange;
+ for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
+ it != e; ++it) {
+ // Skip control edges that aren't executable.
+ Block *predecessor = *it;
+ if (!getOrCreateFor<Executable>(
+ block, getProgramPoint<CFGEdge>(predecessor, block))
+ ->isLive())
+ continue;
+
+ // Merge in the state from the predecessor's terminator.
+ result |=
+ state->join(*getLatticeFor(block, predecessor->getTerminator()));
+ }
+ return result;
+ });
}
void AbstractDenseDataFlowAnalysis::visitRegionBranchOperation(
ProgramPoint point, RegionBranchOpInterface branch,
- AbstractDenseLattice *after) {
+ AbstractDenseElement *after) {
// Get the terminator predecessors.
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
- for (Operation *op : predecessors->getKnownPredecessors()) {
- const AbstractDenseLattice *before;
- // If the predecessor is the parent, get the state before the parent.
- if (op == branch) {
- if (Operation *prev = op->getPrevNode())
- before = getLatticeFor(point, prev);
- else
- before = getLatticeFor(point, op->getBlock());
-
- // Otherwise, get the state after the terminator.
- } else {
- before = getLatticeFor(point, op);
+ update(after, [&](AbstractDenseState *state) {
+ ChangeResult result = ChangeResult::NoChange;
+ for (Operation *op : predecessors->getKnownPredecessors()) {
+ const AbstractDenseState *before;
+ // If the predecessor is the parent, get the state before the parent.
+ if (op == branch) {
+ if (Operation *prev = op->getPrevNode())
+ before = getLatticeFor(point, prev);
+ else
+ before = getLatticeFor(point, op->getBlock());
+
+ // Otherwise, get the state after the terminator.
+ } else {
+ before = getLatticeFor(point, op);
+ }
+ result |= state->join(*before);
}
- join(after, *before);
- }
+ return result;
+ });
}
-const AbstractDenseLattice *
+const AbstractDenseState *
AbstractDenseDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
ProgramPoint point) {
- AbstractDenseLattice *state = getLattice(point);
- addDependency(state, dependent);
- return state;
+ AbstractDenseElement *element = getLattice(point);
+ element->addDependency(this, dependent);
+ return element->get();
}
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index e983341faf02..434f7d44f751 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -23,7 +23,7 @@
using namespace mlir;
using namespace mlir::dataflow;
-IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
+IntegerValueRange IntegerValueRange::getPessimisticValue(Value value) {
unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
APInt umin = APInt::getMinValue(width);
APInt umax = APInt::getMaxValue(width);
@@ -32,30 +32,9 @@ IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
return {{umin, umax, smin, smax}};
}
-void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
- Lattice::onUpdate(solver);
-
- // If the integer range can be narrowed to a constant, update the constant
- // value of the SSA value.
- Optional<APInt> constant = getValue().getValue().getConstantValue();
- auto value = point.get<Value>();
- auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
- if (!constant)
- return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
-
- Dialect *dialect;
- if (auto *parent = value.getDefiningOp())
- dialect = parent->getDialect();
- else
- dialect = value.getParentBlock()->getParentOp()->getDialect();
- solver->propagateIfChanged(
- cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
- dialect)));
-}
-
void IntegerRangeAnalysis::visitOperation(
- Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
- ArrayRef<IntegerValueRangeLattice *> results) {
+ Operation *op, ArrayRef<const IntegerValueRangeState *> operands,
+ ArrayRef<IntegerValueRangeState::ElementT *> results) {
// Ignore non-integer outputs - return early if the op has no scalar
// integer results
bool hasIntegerResult = false;
@@ -63,8 +42,9 @@ void IntegerRangeAnalysis::visitOperation(
if (std::get<1>(it).getType().isIntOrIndex()) {
hasIntegerResult = true;
} else {
- propagateIfChanged(std::get<0>(it),
- std::get<0>(it)->markPessimisticFixpoint());
+ std::get<0>(it)->update(this, [](IntegerValueRangeState *state) {
+ return state->markPessimisticFixpoint();
+ });
}
}
if (!hasIntegerResult)
@@ -76,7 +56,7 @@ void IntegerRangeAnalysis::visitOperation(
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
SmallVector<ConstantIntRanges> argRanges(
- llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
+ llvm::map_range(operands, [](const IntegerValueRangeState *val) {
return val->getValue().getValue();
}));
@@ -87,26 +67,28 @@ void IntegerRangeAnalysis::visitOperation(
assert(llvm::find(op->getResults(), result) != op->result_end());
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
- IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
- Optional<IntegerValueRange> oldRange;
- if (!lattice->isUninitialized())
- oldRange = lattice->getValue();
-
- ChangeResult changed = lattice->join(attrs);
-
- // Catch loop results with loop variant bounds and conservatively make
- // them [-inf, inf] so we don't circle around infinitely often (because
- // the dataflow analysis in MLIR doesn't attempt to work out trip counts
- // and often can't).
- bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
- return op->hasTrait<OpTrait::IsTerminator>();
- });
- if (isYieldedResult && oldRange.hasValue() &&
- !(lattice->getValue() == *oldRange)) {
- LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
- changed |= lattice->markPessimisticFixpoint();
- }
- propagateIfChanged(lattice, changed);
+ results[result.getResultNumber()]->update(
+ this, [&](IntegerValueRangeState *state) {
+ Optional<IntegerValueRange> oldRange;
+ if (!state->isUninitialized())
+ oldRange = state->getValue();
+
+ ChangeResult changed = state->join(attrs);
+
+ // Catch loop results with loop variant bounds and conservatively make
+ // them [-inf, inf] so we don't circle around infinitely often
+ // (because the dataflow analysis in MLIR doesn't attempt to work out
+ // trip counts and often can't).
+ bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+ return op->hasTrait<OpTrait::IsTerminator>();
+ });
+ if (isYieldedResult && oldRange.hasValue() &&
+ !(state->getValue() == *oldRange)) {
+ LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ changed |= state->markPessimisticFixpoint();
+ }
+ return changed;
+ });
};
inferrable.inferResultRanges(argRanges, joinCallback);
@@ -114,7 +96,8 @@ void IntegerRangeAnalysis::visitOperation(
void IntegerRangeAnalysis::visitNonControlFlowArguments(
Operation *op, const RegionSuccessor &successor,
- ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
+ ArrayRef<IntegerValueRangeState::ElementT *> argLattices,
+ unsigned firstIndex) {
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
SmallVector<ConstantIntRanges> argRanges(
@@ -131,25 +114,28 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return;
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
- IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
- Optional<IntegerValueRange> oldRange;
- if (!lattice->isUninitialized())
- oldRange = lattice->getValue();
-
- ChangeResult changed = lattice->join(attrs);
-
- // Catch loop results with loop variant bounds and conservatively make
- // them [-inf, inf] so we don't circle around infinitely often (because
- // the dataflow analysis in MLIR doesn't attempt to work out trip counts
- // and often can't).
- bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
- return op->hasTrait<OpTrait::IsTerminator>();
- });
- if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
- LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
- changed |= lattice->markPessimisticFixpoint();
- }
- propagateIfChanged(lattice, changed);
+ argLattices[arg.getArgNumber()]->update(
+ this, [&](IntegerValueRangeState *state) {
+ Optional<IntegerValueRange> oldRange;
+ if (!state->isUninitialized())
+ oldRange = state->getValue();
+
+ ChangeResult changed = state->join(attrs);
+
+ // Catch loop results with loop variant bounds and conservatively
+ // make them [-inf, inf] so we don't circle around infinitely often
+ // (because the dataflow analysis in MLIR doesn't attempt to work
+ // out trip counts and often can't).
+ bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
+ return op->hasTrait<OpTrait::IsTerminator>();
+ });
+ if (isYieldedValue && oldRange &&
+ !(state->getValue() == *oldRange)) {
+ LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ changed |= state->markPessimisticFixpoint();
+ }
+ return changed;
+ });
};
inferrable.inferResultRanges(argRanges, joinCallback);
@@ -168,11 +154,9 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
return bound.getValue();
} else if (auto value = loopBound->dyn_cast<Value>()) {
- const IntegerValueRangeLattice *lattice =
- getLatticeElementFor(op, value);
- if (lattice != nullptr)
- return getUpper ? lattice->getValue().getValue().smax()
- : lattice->getValue().getValue().smin();
+ const IntegerValueRangeState *state = getLatticeElementFor(op, value);
+ return getUpper ? state->getValue().getValue().smax()
+ : state->getValue().getValue().smin();
}
}
// Given the results of getConstant{Lower,Upper}Bound()
@@ -192,13 +176,10 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
Optional<OpFoldResult> step = loop.getSingleStep();
- APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
- /*getUpper=*/false);
- APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
- /*getUpper=*/true);
+ APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), false);
+ APInt max = getLoopBoundFromFold(upperBound, iv->getType(), true);
// Assume positivity for uniscoverable steps by way of getUpper = true.
- APInt stepVal =
- getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
+ APInt stepVal = getLoopBoundFromFold(step, iv->getType(), true);
if (stepVal.isNegative()) {
std::swap(min, max);
@@ -208,12 +189,53 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
max -= 1;
}
- IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
+ auto *ivEntry = getLatticeElement(*iv);
auto ivRange = ConstantIntRanges::fromSigned(min, max);
- propagateIfChanged(ivEntry, ivEntry->join(ivRange));
- return;
+ return ivEntry->update(this, [&ivRange](IntegerValueRangeState *state) {
+ return state->join(ivRange);
+ });
}
return SparseDataFlowAnalysis::visitNonControlFlowArguments(
op, successor, argLattices, firstIndex);
}
+
+LogicalResult IntegerRangeToConstant::initialize(Operation *top) {
+ auto visitValues = [this](ValueRange values) {
+ for (Value value : values)
+ (void)visit(value);
+ };
+ top->walk([&](Operation *op) {
+ visitValues(op->getResults());
+ for (Region &region : op->getRegions())
+ for (Block &block : region)
+ visitValues(block.getArguments());
+ });
+ return success();
+}
+
+LogicalResult IntegerRangeToConstant::visit(ProgramPoint point) {
+ auto value = point.get<Value>();
+ auto *rangeState = getOrCreateFor<IntegerValueRangeState>(value, value);
+ if (rangeState->isUninitialized())
+ return success();
+
+ update<ConstantValueState>(value, [&](ConstantValueState *state) {
+ const ConstantIntRanges &range = rangeState->getValue().getValue();
+ // Try to narrow to a constant.
+ Optional<APInt> constant = range.getConstantValue();
+ if (!constant)
+ return state->markPessimisticFixpoint();
+
+ // Find a dialect to materialize the constant.
+ Dialect *dialect;
+ if (Operation *op = value.getDefiningOp())
+ dialect = op->getDialect();
+ else
+ dialect = value.getParentRegion()->getParentOp()->getDialect();
+
+ Attribute attr = IntegerAttr::get(value.getType(), *constant);
+ return state->join(ConstantValue(attr, dialect));
+ });
+ return success();
+}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 4caa5eea326a..4a08efc7b6f1 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -12,17 +12,6 @@
using namespace mlir;
using namespace mlir::dataflow;
-//===----------------------------------------------------------------------===//
-// AbstractSparseLattice
-//===----------------------------------------------------------------------===//
-
-void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
- // Push all users of the value to the queue.
- for (Operation *user : point.get<Value>().getUsers())
- for (DataFlowAnalysis *analysis : useDefSubscribers)
- solver->enqueue({user, analysis});
-}
-
//===----------------------------------------------------------------------===//
// AbstractSparseDataFlowAnalysis
//===----------------------------------------------------------------------===//
@@ -80,28 +69,26 @@ void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
return;
// If the containing block is not executable, bail out.
- if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ if (!getOrCreate<Executable>(op->getBlock())->get()->isLive())
return;
// Get the result lattices.
- SmallVector<AbstractSparseLattice *> resultLattices;
- resultLattices.reserve(op->getNumResults());
+ SmallVector<AbstractSparseElement *> resultElements;
+ resultElements.reserve(op->getNumResults());
// Track whether all results have reached their fixpoint.
bool allAtFixpoint = true;
for (Value result : op->getResults()) {
- AbstractSparseLattice *resultLattice = getLatticeElement(result);
- allAtFixpoint &= resultLattice->isAtFixpoint();
- resultLattices.push_back(resultLattice);
+ AbstractSparseElement *resultElement = getLatticeElement(result);
+ allAtFixpoint &= resultElement->get()->isAtFixpoint();
+ resultElements.push_back(resultElement);
}
// If all result lattices have reached a fixpoint, there is nothing to do.
if (allAtFixpoint)
return;
// The results of a region branch operation are determined by control-flow.
- if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- return visitRegionSuccessors({branch}, branch,
- /*successorIndex=*/llvm::None, resultLattices);
- }
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
+ return visitRegionSuccessors({branch}, branch, llvm::None, resultElements);
// The results of a call operation are determined by the callgraph.
if (auto call = dyn_cast<CallOpInterface>(op)) {
@@ -109,27 +96,27 @@ void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
if (!predecessors->allPredecessorsKnown())
- return markAllPessimisticFixpoint(resultLattices);
+ return markAllPessimisticFixpoint(resultElements);
for (Operation *predecessor : predecessors->getKnownPredecessors())
- for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
+ for (auto it : llvm::zip(predecessor->getOperands(), resultElements))
join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
return;
}
// Grab the lattice elements of the operands.
- SmallVector<const AbstractSparseLattice *> operandLattices;
- operandLattices.reserve(op->getNumOperands());
+ SmallVector<const AbstractSparseState *> operandStates;
+ operandStates.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
- AbstractSparseLattice *operandLattice = getLatticeElement(operand);
- operandLattice->useDefSubscribe(this);
+ AbstractSparseElement *operandElement = getLatticeElement(operand);
+ operandElement->useDefSubscribe(this);
// If any of the operand states are not initialized, bail out.
- if (operandLattice->isUninitialized())
+ if (operandElement->get()->isUninitialized())
return;
- operandLattices.push_back(operandLattice);
+ operandStates.push_back(operandElement->get());
}
// Invoke the operation transfer function.
- visitOperationImpl(op, operandLattices, resultLattices);
+ visitOperationImpl(op, operandStates, resultElements);
}
void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
@@ -138,17 +125,17 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
return;
// If the block is not executable, bail out.
- if (!getOrCreate<Executable>(block)->isLive())
+ if (!getOrCreate<Executable>(block)->get()->isLive())
return;
// Get the argument lattices.
- SmallVector<AbstractSparseLattice *> argLattices;
- argLattices.reserve(block->getNumArguments());
+ SmallVector<AbstractSparseElement *> argElements;
+ argElements.reserve(block->getNumArguments());
bool allAtFixpoint = true;
for (BlockArgument argument : block->getArguments()) {
- AbstractSparseLattice *argLattice = getLatticeElement(argument);
- allAtFixpoint &= argLattice->isAtFixpoint();
- argLattices.push_back(argLattice);
+ AbstractSparseElement *argElement = getLatticeElement(argument);
+ allAtFixpoint &= argElement->get()->isAtFixpoint();
+ argElements.push_back(argElement);
}
// If all argument lattices have reached their fixpoints, then there is
// nothing to do.
@@ -165,10 +152,10 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown())
- return markAllPessimisticFixpoint(argLattices);
+ return markAllPessimisticFixpoint(argElements);
for (Operation *callsite : callsites->getKnownPredecessors()) {
auto call = cast<CallOpInterface>(callsite);
- for (auto it : llvm::zip(call.getArgOperands(), argLattices))
+ for (auto it : llvm::zip(call.getArgOperands(), argElements))
join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
}
return;
@@ -177,13 +164,13 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
return visitRegionSuccessors(
- block, branch, block->getParent()->getRegionNumber(), argLattices);
+ block, branch, block->getParent()->getRegionNumber(), argElements);
}
// Otherwise, we can't reason about the data-flow.
return visitNonControlFlowArgumentsImpl(block->getParentOp(),
RegionSuccessor(block->getParent()),
- argLattices, /*firstIndex=*/0);
+ argElements, /*firstIndex=*/0);
}
// Iterate over the predecessors of the non-entry block.
@@ -196,7 +183,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
auto *edgeExecutable =
getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
edgeExecutable->blockContentSubscribe(this);
- if (!edgeExecutable->isLive())
+ if (!edgeExecutable->get()->isLive())
continue;
// Check if we can reason about the data-flow from the predecessor.
@@ -204,7 +191,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
SuccessorOperands operands =
branch.getSuccessorOperands(it.getSuccessorIndex());
- for (auto &it : llvm::enumerate(argLattices)) {
+ for (auto &it : llvm::enumerate(argElements)) {
if (Value operand = operands[it.index()]) {
join(it.value(), *getLatticeElementFor(block, operand));
} else {
@@ -214,7 +201,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
}
}
} else {
- return markAllPessimisticFixpoint(argLattices);
+ return markAllPessimisticFixpoint(argElements);
}
}
}
@@ -222,7 +209,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint point, RegionBranchOpInterface branch,
Optional<unsigned> successorIndex,
- ArrayRef<AbstractSparseLattice *> lattices) {
+ ArrayRef<AbstractSparseElement *> elements) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
@@ -242,7 +229,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
if (!operands) {
// We can't reason about the data-flow.
- return markAllPessimisticFixpoint(lattices);
+ return markAllPessimisticFixpoint(elements);
}
ValueRange inputs = predecessors->getSuccessorInputs(op);
@@ -250,7 +237,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
"expected the same number of successor inputs as operands");
unsigned firstIndex = 0;
- if (inputs.size() != lattices.size()) {
+ if (inputs.size() != elements.size()) {
if (auto *op = point.dyn_cast<Operation *>()) {
if (!inputs.empty())
firstIndex = inputs.front().cast<OpResult>().getResultNumber();
@@ -258,7 +245,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
branch,
RegionSuccessor(
branch->getResults().slice(firstIndex, inputs.size())),
- lattices, firstIndex);
+ elements, firstIndex);
} else {
if (!inputs.empty())
firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
@@ -267,36 +254,42 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
branch,
RegionSuccessor(region, region->getArguments().slice(
firstIndex, inputs.size())),
- lattices, firstIndex);
+ elements, firstIndex);
}
}
- for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
+ for (auto it : llvm::zip(*operands, elements.drop_front(firstIndex)))
join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
}
}
-const AbstractSparseLattice *
+const AbstractSparseState *
AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
Value value) {
- AbstractSparseLattice *state = getLatticeElement(value);
- addDependency(state, point);
- return state;
+ AbstractSparseElement *element = getLatticeElement(value);
+ element->addDependency(this, point);
+ return element->get();
}
void AbstractSparseDataFlowAnalysis::markPessimisticFixpoint(
- AbstractSparseLattice *lattice) {
- propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
+ AbstractSparseElement *element) {
+ element->update(this, [](AbstractState *state) {
+ return static_cast<AbstractSparseState *>(state)
+ ->markPessimisticFixpoint();
+ });
}
void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
- ArrayRef<AbstractSparseLattice *> lattices) {
- for (AbstractSparseLattice *lattice : lattices) {
- markPessimisticFixpoint(lattice);
- }
+ ArrayRef<AbstractSparseElement *> elements) {
+ for (AbstractSparseElement *element : elements)
+ markPessimisticFixpoint(element);
}
-void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
- const AbstractSparseLattice &rhs) {
- propagateIfChanged(lhs, lhs->join(rhs));
+void AbstractSparseDataFlowAnalysis::join(
+ AbstractSparseElement *lhs,
+ const AbstractSparseState &rhs) {
+ lhs->update(this, [&rhs](AbstractState *lhsState) {
+ return static_cast<AbstractSparseState *>(lhsState)->join(rhs);
+ });
}
+
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 18d9ba1bd5d6..0eca6263dc4e 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -24,12 +24,6 @@ using namespace mlir;
GenericProgramPoint::~GenericProgramPoint() = default;
-//===----------------------------------------------------------------------===//
-// AnalysisState
-//===----------------------------------------------------------------------===//
-
-AnalysisState::~AnalysisState() = default;
-
//===----------------------------------------------------------------------===//
// ProgramPoint
//===----------------------------------------------------------------------===//
@@ -58,6 +52,34 @@ Location ProgramPoint::getLoc() const {
return get<Block *>()->getParent()->getLoc();
}
+//===----------------------------------------------------------------------===//
+// AbstractState and AbstractElement
+//===----------------------------------------------------------------------===//
+
+AbstractState::~AbstractState() = default;
+AbstractElement::~AbstractElement() = default;
+
+void AbstractElement::addDependency(DataFlowAnalysis *analysis,
+ ProgramPoint point) {
+ auto inserted = dependents.insert({point, analysis});
+ (void)inserted;
+ DATAFLOW_DEBUG({
+ if (inserted) {
+ llvm::dbgs() << "Adding dependency from " << debugName << " of "
+ << this->point << " to " << analysis->debugName << " on "
+ << point << "\n";
+ }
+ });
+}
+
+void AbstractElement::propagateUpdate() {
+ DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << debugName << " of "
+ << point << "\nValue: " << *get() << "\n");
+ for (auto &item : dependents)
+ solver.enqueue(item);
+ onUpdate();
+}
+
//===----------------------------------------------------------------------===//
// DataFlowSolver
//===----------------------------------------------------------------------===//
@@ -94,30 +116,12 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
return success();
}
-void DataFlowSolver::propagateIfChanged(AnalysisState *state,
- ChangeResult changed) {
- if (changed == ChangeResult::Change) {
- DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
- << " of " << state->point << "\n"
- << "Value: " << *state << "\n");
- for (const WorkItem &item : state->dependents)
- enqueue(item);
- state->onUpdate(this);
- }
-}
-
-void DataFlowSolver::addDependency(AnalysisState *state,
- DataFlowAnalysis *analysis,
- ProgramPoint point) {
- auto inserted = state->dependents.insert({point, analysis});
- (void)inserted;
- DATAFLOW_DEBUG({
- if (inserted) {
- llvm::dbgs() << "Creating dependency between " << state->debugName
- << " of " << state->point << "\nand " << analysis->debugName
- << " on " << point << "\n";
- }
- });
+void DataFlowSolver::getStaticProvidersFor(
+ TypeID stateID, ProgramPoint point,
+ SmallVectorImpl<DataFlowAnalysis *> &staticProviders) const {
+ for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses))
+ if (analysis.staticallyProvides(stateID, point))
+ staticProviders.push_back(&analysis);
}
//===----------------------------------------------------------------------===//
@@ -127,12 +131,3 @@ void DataFlowSolver::addDependency(AnalysisState *state,
DataFlowAnalysis::~DataFlowAnalysis() = default;
DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
-
-void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
- solver.addDependency(state, this, point);
-}
-
-void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,
- ChangeResult changed) {
- solver.propagateIfChanged(state, changed);
-}
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
index 49d0ac70a604..9ce9cabdcbff 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -23,8 +23,8 @@ using namespace mlir::dataflow;
/// bound on its value (if it is treated as signed) and that bound is
/// non-negative.
static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
- auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
- if (!result)
+ auto *result = solver.lookup<IntegerValueRangeState>(v);
+ if (!result || result->isUninitialized())
return failure();
const ConstantIntRanges &range = result->getValue().getValue();
return success(range.smin().isNonNegative());
@@ -113,6 +113,7 @@ struct ArithmeticUnsignedWhenEquivalentPass
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<IntegerRangeAnalysis>();
+ solver.load<IntegerRangeToConstant>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 27cda9be50ab..b0e447b9b14d 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -37,7 +37,7 @@ using namespace mlir::dataflow;
static LogicalResult replaceWithConstant(DataFlowSolver &solver,
OpBuilder &builder,
OperationFolder &folder, Value value) {
- auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
+ auto *lattice = solver.lookup<ConstantValueState>(value);
if (!lattice || lattice->isUninitialized())
return failure();
const ConstantValue &latticeValue = lattice->getValue();
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index 27e994cce3b6..d12e57e1f509 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -29,7 +29,7 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
os << " ";
block.printAsOperand(os);
os << " = ";
- auto *live = solver.lookupState<Executable>(&block);
+ auto *live = solver.lookup<Executable>(&block);
if (live)
os << *live;
else
@@ -39,7 +39,7 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
os << " from ";
pred->printAsOperand(os);
os << " = ";
- auto *live = solver.lookupState<Executable>(
+ auto *live = solver.lookup<Executable>(
solver.getProgramPoint<CFGEdge>(pred, &block));
if (live)
os << *live;
@@ -49,12 +49,12 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
}
}
if (!region.empty()) {
- auto *preds = solver.lookupState<PredecessorState>(&region.front());
+ auto *preds = solver.lookup<PredecessorState>(&region.front());
if (preds)
os << "region_preds: " << *preds << "\n";
}
}
- auto *preds = solver.lookupState<PredecessorState>(op);
+ auto *preds = solver.lookup<PredecessorState>(op);
if (preds)
os << "op_preds: " << *preds << "\n";
});
@@ -79,9 +79,10 @@ struct ConstantAnalysis : public DataFlowAnalysis {
Operation *op = point.get<Operation *>();
Attribute value;
if (matchPattern(op, m_Constant(&value))) {
- auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
- propagateIfChanged(
- constant, constant->join(ConstantValue(value, op->getDialect())));
+ auto *constant = getOrCreate<ConstantValueState>(op->getResult(0));
+ constant->update(this, [value, op](ConstantValueState *state) {
+ return state->join(ConstantValue(value, op->getDialect()));
+ });
return success();
}
markAllPessimisticFixpoint(op->getResults());
@@ -94,9 +95,10 @@ struct ConstantAnalysis : public DataFlowAnalysis {
/// pessimistic fixpoint.
void markAllPessimisticFixpoint(ValueRange values) {
for (Value value : values) {
- auto *constantValue = getOrCreate<Lattice<ConstantValue>>(value);
- propagateIfChanged(constantValue,
- constantValue->markPessimisticFixpoint());
+ auto *constant = getOrCreate<ConstantValueState>(value);
+ constant->update(this, [](ConstantValueState *state) {
+ return state->markPessimisticFixpoint();
+ });
}
}
};
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
index 98ccb6f94747..3dc63146a591 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
@@ -20,9 +20,7 @@ namespace {
class UnderlyingValue {
public:
/// The pessimistic underlying value of a value is itself.
- static UnderlyingValue getPessimisticValueState(Value value) {
- return {value};
- }
+ static UnderlyingValue getPessimisticValue(Value value) { return {value}; }
/// Create an underlying value state with a known underlying value.
UnderlyingValue(Value underlyingValue = {})
@@ -51,21 +49,21 @@ class UnderlyingValue {
/// This lattice represents, for a given memory resource, the potential last
/// operations that modified the resource.
-class LastModification : public AbstractDenseLattice {
+class LastModification : public AbstractDenseState {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
- using AbstractDenseLattice::AbstractDenseLattice;
+ using ElementT =
+ SingleStateElement<LastModification, AbstractDenseElement>;
+
+ explicit LastModification(ProgramPoint point) {}
/// The lattice is always initialized.
bool isUninitialized() const override { return false; }
- /// Initialize the lattice. Does nothing.
- ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
-
/// Mark the lattice as having reached its pessimistic fixpoint. That is, the
/// last modifications of all memory resources are unknown.
- ChangeResult reset() override {
+ ChangeResult markPessimisticFixpoint() override {
if (lastMods.empty())
return ChangeResult::NoChange;
lastMods.clear();
@@ -76,7 +74,7 @@ class LastModification : public AbstractDenseLattice {
bool isAtFixpoint() const override { return false; }
/// Join the last modifications.
- ChangeResult join(const AbstractDenseLattice &lattice) override {
+ ChangeResult join(const AbstractDenseState &lattice) override {
const auto &rhs = static_cast<const LastModification &>(lattice);
ChangeResult result = ChangeResult::NoChange;
for (const auto &mod : rhs.lastMods) {
@@ -135,13 +133,16 @@ class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
/// its reaching definitions is set to empty. If the operation writes to a
/// resource, then its reaching definition is set to the written value.
void visitOperation(Operation *op, const LastModification &before,
- LastModification *after) override;
+ LastModification::ElementT *after) override;
};
/// Define the lattice class explicitly to provide a type ID.
-struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
- using Lattice::Lattice;
+struct UnderlyingValueState : public OptimisticSparseState<UnderlyingValue> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueState)
+
+ using OptimisticSparseState::OptimisticSparseState;
+ using ElementT =
+ SparseElement<UnderlyingValueState, SingleStateElement>;
};
/// An analysis that uses forwarding of values along control-flow and callgraph
@@ -149,14 +150,14 @@ struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
/// analysis exists so that the test analysis and pass can test the behaviour of
/// the dense data-flow analysis on the callgraph.
class UnderlyingValueAnalysis
- : public SparseDataFlowAnalysis<UnderlyingValueLattice> {
+ : public SparseDataFlowAnalysis<UnderlyingValueState> {
public:
using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
/// The underlying value of the results of an operation are not known.
- void visitOperation(Operation *op,
- ArrayRef<const UnderlyingValueLattice *> operands,
- ArrayRef<UnderlyingValueLattice *> results) override {
+ void
+ visitOperation(Operation *op, ArrayRef<const UnderlyingValueState *> operands,
+ ArrayRef<UnderlyingValueState::ElementT *> results) override {
markAllPessimisticFixpoint(results);
}
};
@@ -165,8 +166,8 @@ class UnderlyingValueAnalysis
/// Look for the most underlying value of a value.
static Value getMostUnderlyingValue(
Value value,
- function_ref<const UnderlyingValueLattice *(Value)> getUnderlyingValueFn) {
- const UnderlyingValueLattice *underlying;
+ function_ref<const UnderlyingValueState *(Value)> getUnderlyingValueFn) {
+ const UnderlyingValueState *underlying;
do {
underlying = getUnderlyingValueFn(value);
if (!underlying || underlying->isUninitialized())
@@ -181,38 +182,42 @@ static Value getMostUnderlyingValue(
void LastModifiedAnalysis::visitOperation(Operation *op,
const LastModification &before,
- LastModification *after) {
+ LastModification::ElementT *after) {
auto memory = dyn_cast<MemoryEffectOpInterface>(op);
// If we can't reason about the memory effects, then conservatively assume we
// can't deduce anything about the last modifications.
if (!memory)
- return reset(after);
+ return markPessimisticFixpoint(after);
SmallVector<MemoryEffects::EffectInstance> effects;
memory.getEffects(effects);
- ChangeResult result = after->join(before);
- for (const auto &effect : effects) {
- Value value = effect.getValue();
+ after->update(this, [&](LastModification *state) {
+ ChangeResult result = state->join(before);
+ for (const auto &effect : effects) {
+ Value value = effect.getValue();
- // If we see an effect on anything other than a value, assume we can't
- // deduce anything about the last modifications.
- if (!value)
- return reset(after);
+ // If we see an effect on anything other than a value, assume we can't
+ // deduce anything about the last modifications.
+ if (!value) {
+ result |= state->markPessimisticFixpoint();
+ break;
+ }
- value = getMostUnderlyingValue(value, [&](Value value) {
- return getOrCreateFor<UnderlyingValueLattice>(op, value);
- });
- if (!value)
- return;
+ value = getMostUnderlyingValue(value, [&](Value value) {
+ return getOrCreateFor<UnderlyingValueState>(op, value);
+ });
+ if (!value)
+ return ChangeResult::NoChange;
- // Nothing to do for reads.
- if (isa<MemoryEffects::Read>(effect.getEffect()))
- continue;
+ // Nothing to do for reads.
+ if (isa<MemoryEffects::Read>(effect.getEffect()))
+ continue;
- result |= after->set(value, op);
- }
- propagateIfChanged(after, result);
+ result |= state->set(value, op);
+ }
+ return result;
+ });
}
namespace {
@@ -240,13 +245,12 @@ struct TestLastModifiedPass
if (!tag)
return;
os << "test_tag: " << tag.getValue() << ":\n";
- const LastModification *lastMods =
- solver.lookupState<LastModification>(op);
+ const auto *lastMods = solver.lookup<LastModification>(op);
assert(lastMods && "expected a dense lattice");
for (auto &it : llvm::enumerate(op->getOperands())) {
os << " operand #" << it.index() << "\n";
Value value = getMostUnderlyingValue(it.value(), [&](Value value) {
- return solver.lookupState<UnderlyingValueLattice>(value);
+ return solver.lookup<UnderlyingValueState>(value);
});
assert(value && "expected an underlying value");
if (Optional<ArrayRef<Operation *>> lastMod =
diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 329be3c5446f..2fd41b25b765 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -14,17 +14,15 @@ using namespace mlir;
namespace {
/// This analysis state represents an integer that is XOR'd with other states.
-class FooState : public AnalysisState {
+class FooState : public AbstractState {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
- using AnalysisState::AnalysisState;
+ using ElementT = SingleStateElement<FooState>;
- /// Default-initialize the state to zero.
- ChangeResult defaultInitialize() override { return join(0); }
+ explicit FooState(ProgramPoint point) {}
- /// Returns true if the state is uninitialized.
- bool isUninitialized() const override { return !state; }
+ bool isUninitialized() const { return !state; }
/// Print the integer value or "none" if uninitialized.
void print(raw_ostream &os) const override {
@@ -99,7 +97,8 @@ LogicalResult FooAnalysis::initialize(Operation *top) {
return top->emitError("expected a single region top-level op");
// Initialize the top-level state.
- getOrCreate<FooState>(&top->getRegion(0).front())->join(0);
+ update<FooState>(&top->getRegion(0).front(),
+ [](FooState *state) { return state->join(0); });
// Visit all nested blocks and operations.
for (Block &block : top->getRegion(0)) {
@@ -130,35 +129,37 @@ void FooAnalysis::visitBlock(Block *block) {
// This is the initial state. Let the framework default-initialize it.
return;
}
- FooState *state = getOrCreate<FooState>(block);
- ChangeResult result = ChangeResult::NoChange;
- for (Block *pred : block->getPredecessors()) {
- // Join the state at the terminators of all predecessors.
- const FooState *predState =
- getOrCreateFor<FooState>(block, pred->getTerminator());
- result |= state->join(*predState);
- }
- propagateIfChanged(state, result);
+ update<FooState>(block, [&](FooState *state) {
+ ChangeResult result = ChangeResult::NoChange;
+ for (Block *pred : block->getPredecessors()) {
+ // Join the state at the terminators of all predecessors.
+ const FooState *predState =
+ getOrCreateFor<FooState>(block, pred->getTerminator());
+ result |= state->join(*predState);
+ }
+ return result;
+ });
}
void FooAnalysis::visitOperation(Operation *op) {
- FooState *state = getOrCreate<FooState>(op);
- ChangeResult result = ChangeResult::NoChange;
-
- // Copy the state across the operation.
- const FooState *prevState;
- if (Operation *prev = op->getPrevNode())
- prevState = getOrCreateFor<FooState>(op, prev);
- else
- prevState = getOrCreateFor<FooState>(op, op->getBlock());
- result |= state->set(*prevState);
-
- // Modify the state with the attribute, if specified.
- if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
- uint64_t value = attr.getUInt();
- result |= state->join(value);
- }
- propagateIfChanged(state, result);
+ update<FooState>(op, [&](FooState *state) {
+ ChangeResult result = ChangeResult::NoChange;
+
+ // Copy the state across the operation.
+ const FooState *prevState;
+ if (Operation *prev = op->getPrevNode())
+ prevState = getOrCreateFor<FooState>(op, prev);
+ else
+ prevState = getOrCreateFor<FooState>(op, op->getBlock());
+ result |= state->set(*prevState);
+
+ // Modify the state with the attribute, if specified.
+ if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
+ uint64_t value = attr.getUInt();
+ result |= state->join(value);
+ }
+ return result;
+ });
}
void TestFooAnalysisPass::runOnOperation() {
@@ -175,7 +176,7 @@ void TestFooAnalysisPass::runOnOperation() {
auto tag = op->getAttrOfType<StringAttr>("tag");
if (!tag)
return;
- const FooState *state = solver.lookupState<FooState>(op);
+ const FooState *state = solver.lookup<FooState>(op);
assert(state && !state->isUninitialized());
os << tag.getValue() << " -> " << state->getValue() << "\n";
});
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
index 05e908aae311..f5bfff8e6e1a 100644
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
@@ -23,23 +23,15 @@ using namespace mlir::dataflow;
/// Patterned after SCCP
static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
OperationFolder &folder, Value value) {
- auto *maybeInferredRange =
- solver.lookupState<IntegerValueRangeLattice>(value);
- if (!maybeInferredRange || maybeInferredRange->isUninitialized())
- return failure();
- const ConstantIntRanges &inferredRange =
- maybeInferredRange->getValue().getValue();
- Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
- if (!maybeConstValue.hasValue())
+ auto *constantState = solver.lookup<ConstantValueState>(value);
+ if (!constantState || constantState->isUninitialized() ||
+ !constantState->getValue().getConstantValue())
return failure();
- Operation *maybeDefiningOp = value.getDefiningOp();
- Dialect *valueDialect =
- maybeDefiningOp ? maybeDefiningOp->getDialect()
- : value.getParentRegion()->getParentOp()->getDialect();
- Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
- Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
- value.getType(), value.getLoc());
+ const ConstantValue &constantValue = constantState->getValue();
+ Value constant = folder.getOrCreateConstant(
+ b, constantValue.getConstantDialect(), constantValue.getConstantValue(),
+ value.getType(), value.getLoc());
if (!constant)
return failure();
@@ -106,6 +98,7 @@ struct TestIntRangeInference
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<IntegerRangeAnalysis>();
+ solver.load<IntegerRangeToConstant>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
rewrite(solver, op->getContext(), op->getRegions());
More information about the llvm-branch-commits
mailing list