[llvm-branch-commits] [mlir] 1ed1e8c - overhaul state management and allow multi state elements

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jun 29 17:03:51 PDT 2022


Author: Mogball
Date: 2022-06-29T17:03:09-07:00
New Revision: 1ed1e8c7843248091019d5e902b73405567dc464

URL: https://github.com/llvm/llvm-project/commit/1ed1e8c7843248091019d5e902b73405567dc464
DIFF: https://github.com/llvm/llvm-project/commit/1ed1e8c7843248091019d5e902b73405567dc464.diff

LOG: overhaul state management and allow multi state elements

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/include/mlir/Analysis/DataFlowFramework.h
    mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
    mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
    mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
    mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
    mlir/lib/Analysis/DataFlowFramework.cpp
    mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
    mlir/lib/Transforms/SCCP.cpp
    mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
    mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
    mlir/test/lib/Analysis/TestDataFlowFramework.cpp
    mlir/test/lib/Transforms/TestIntRangeInference.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index 8e4baea08ab1..518e5ac23b39 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -29,7 +29,7 @@ namespace dataflow {
 class ConstantValue {
 public:
   /// The pessimistic value state of the constant value is unknown.
-  static ConstantValue getPessimisticValueState(Value value) { return {}; }
+  static ConstantValue getPessimisticValue(Value value) { return {}; }

   /// Construct a constant value with a known constant.
   ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr)
@@ -53,6 +53,19 @@ class ConstantValue {
     return lhs == rhs ? lhs : ConstantValue();
   }

+  /// Compute the meet of two constant values. If either value is unknown, the
+  /// other one is returned; two different known constants meet to unknown.
+  static ConstantValue meet(const ConstantValue &lhs,
+                            const ConstantValue &rhs) {
+    if (lhs == rhs)
+      return lhs;
+    if (!lhs.constant)
+      return rhs;
+    if (!rhs.constant)
+      return lhs;
+    return ConstantValue();
+  }
+
   /// Print the constant value.
   void print(raw_ostream &os) const;

@@ -63,6 +76,15 @@ class ConstantValue {
   Dialect *dialect;
 };

+/// The sparse state holding the constant value of an SSA value. Its elements
+/// are `MultiStateElement`s, presumably so that more than one analysis can
+/// provide this state for the same value.
+class ConstantValueState : public OptimisticSparseState<ConstantValue> {
+public:
+  using OptimisticSparseState::OptimisticSparseState;
+  using ElementT = SparseElement<ConstantValueState, MultiStateElement>;
+};
+
 //===----------------------------------------------------------------------===//
 // SparseConstantPropagation
 //===----------------------------------------------------------------------===//
@@ -72,13 +94,13 @@ class ConstantValue {
 /// operands, by speculatively folding operations. When combined with dead-code
 /// analysis, this becomes sparse conditional constant propagation (SCCP).
 class SparseConstantPropagation
-    : public SparseDataFlowAnalysis<Lattice<ConstantValue>> {
+    : public SparseDataFlowAnalysis<ConstantValueState> {
 public:
   using SparseDataFlowAnalysis::SparseDataFlowAnalysis;

-  void visitOperation(Operation *op,
-                      ArrayRef<const Lattice<ConstantValue> *> operands,
-                      ArrayRef<Lattice<ConstantValue> *> results) override;
+  void
+  visitOperation(Operation *op, ArrayRef<const ConstantValueState *> operands,
+                 ArrayRef<ConstantValueState::ElementT *> results) override;
 };

 } // end namespace dataflow

diff  --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 615c7eece33a..6c8dca49af55 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -24,21 +24,77 @@
 namespace mlir {
 namespace dataflow {

+//===----------------------------------------------------------------------===//
+// CFGEdge
+//===----------------------------------------------------------------------===//
+
+/// This program point represents a control-flow edge between a block and one
+/// of its successors.
+class CFGEdge
+    : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
+public:
+  using Base::Base;
+
+  /// Get the block from which the edge originates.
+  Block *getFrom() const { return getValue().first; }
+  /// Get the target block.
+  Block *getTo() const { return getValue().second; }
+
+  /// Print the blocks between the control-flow edge.
+  void print(raw_ostream &os) const override;
+  /// Get a fused location of both blocks.
+  Location getLoc() const override;
+};
+
 //===----------------------------------------------------------------------===//
 // Executable
 //===----------------------------------------------------------------------===//

 /// This is a simple analysis state that represents whether the associated
 /// program point (either a block or a control-flow edge) is live.
-class Executable : public AnalysisState {
+class Executable : public AbstractState {
 public:
-  using AnalysisState::AnalysisState;
-
-  /// The state is initialized by default.
-  bool isUninitialized() const override { return false; }
-
-  /// The state is always initialized.
-  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+  template <typename ExecutableT>
+  class Element : public SingleStateElement<ExecutableT> {
+  public:
+    using SingleStateElement<ExecutableT>::SingleStateElement;
+
+    /// When the state of the program point is changed to live, re-invoke
+    /// subscribed analyses on the operations in the block and on the block
+    /// itself.
+    void onUpdate() override {
+      if (auto *block = this->point.template dyn_cast<Block *>()) {
+        // Re-invoke the analyses on the block itself.
+        for (DataFlowAnalysis *analysis : subscribers)
+          this->solver.enqueue({block, analysis});
+        // Re-invoke the analyses on all operations in the block.
+        for (DataFlowAnalysis *analysis : subscribers)
+          for (Operation &op : *block)
+            this->solver.enqueue({&op, analysis});
+      } else if (auto *programPoint =
+                     this->point.template dyn_cast<GenericProgramPoint *>()) {
+        // Re-invoke the analysis on the successor block.
+        if (auto *edge = dyn_cast<CFGEdge>(programPoint))
+          for (DataFlowAnalysis *analysis : subscribers)
+            this->solver.enqueue({edge->getTo(), analysis});
+      }
+    }
+
+    /// Subscribe an analysis to changes to the liveness.
+    void blockContentSubscribe(DataFlowAnalysis *analysis) {
+      subscribers.insert(analysis);
+    }
+
+  private:
+    /// A set of analyses that should be updated when this state changes.
+    SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
+              SmallPtrSet<DataFlowAnalysis *, 4>>
+        subscribers;
+  };
+  using ElementT = Element<Executable>;
+
+  /// Optimistically assume the program point is dead.
+  explicit Executable(ProgramPoint point) : live(false) {}

   /// Set the state of the program point to live.
   ChangeResult setToLive();
@@ -46,28 +102,12 @@ class Executable : public AnalysisState {
   /// Get whether the program point is live.
   bool isLive() const { return live; }

   /// Print the liveness.
   void print(raw_ostream &os) const override;

-  /// When the state of the program point is changed to live, re-invoke
-  /// subscribed analyses on the operations in the block and on the block
-  /// itself.
-  void onUpdate(DataFlowSolver *solver) const override;
-
-  /// Subscribe an analysis to changes to the liveness.
-  void blockContentSubscribe(DataFlowAnalysis *analysis) {
-    subscribers.insert(analysis);
-  }
-
 private:
-  /// Whether the program point is live. Optimistically assume that the program
-  /// point is dead.
-  bool live = false;
-
-  /// A set of analyses that should be updated when this state changes.
-  SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
-            SmallPtrSet<DataFlowAnalysis *, 4>>
-      subscribers;
+  /// Whether the program point is live.
+  bool live;
 };

 //===----------------------------------------------------------------------===//
@@ -90,15 +130,11 @@ class Executable : public AnalysisState {
 ///
 /// The state can indicate that it is underdefined, meaning that not all live
 /// control-flow predecessors can be known.
-class PredecessorState : public AnalysisState {
+class PredecessorState : public AbstractState {
 public:
-  using AnalysisState::AnalysisState;
+  using ElementT = SingleStateElement<PredecessorState>;

-  /// The state is initialized by default.
-  bool isUninitialized() const override { return false; }
-
-  /// The state is always initialized.
-  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+  explicit PredecessorState(ProgramPoint point) {}

   /// Print the known predecessors.
   void print(raw_ostream &os) const override;
@@ -142,28 +178,6 @@ class PredecessorState : public AnalysisState {
   DenseMap<Operation *, ValueRange> successorInputs;
 };

-//===----------------------------------------------------------------------===//
-// CFGEdge
-//===----------------------------------------------------------------------===//
-
-/// This program point represents a control-flow edge between a block and one
-/// of its successors.
-class CFGEdge
-    : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
-public:
-  using Base::Base;
-
-  /// Get the block from which the edge originates.
-  Block *getFrom() const { return getValue().first; }
-  /// Get the target block.
-  Block *getTo() const { return getValue().second; }
-
-  /// Print the blocks between the control-flow edge.
-  void print(raw_ostream &os) const override;
-  /// Get a fused location of both blocks.
-  Location getLoc() const override;
-};
-
 //===----------------------------------------------------------------------===//
 // DeadCodeAnalysis
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 01b77d958519..96697307a49d 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -16,33 +16,27 @@
 #define MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H

 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"

 namespace mlir {
 namespace dataflow {

 //===----------------------------------------------------------------------===//
-// AbstractDenseLattice
+// AbstractDenseState
 //===----------------------------------------------------------------------===//

-/// This class represents a dense lattice. A dense lattice is attached to
+/// The state of a dense lattice. A dense lattice is attached to
 /// operations to represent the program state after their execution or to blocks
 /// to represent the program state at the beginning of the block. A dense
 /// lattice is propagated through the IR by dense data-flow analysis.
-class AbstractDenseLattice : public AnalysisState {
-public:
-  /// A dense lattice can only be created for operations and blocks.
-  using AnalysisState::AnalysisState;
-
-  /// Join the lattice across control-flow or callgraph edges.
-  virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0;
+using AbstractDenseState = AbstractSparseState;

-  /// Reset the dense lattice to a pessimistic value. This occurs when the
-  /// analysis cannot reason about the data-flow.
-  virtual ChangeResult reset() = 0;
+/// A dense element holds the `AbstractDenseState` attached to a program point.
+class AbstractDenseElement : public AbstractElement {
+public:
+  using AbstractElement::AbstractElement;

-  /// Returns true if the lattice state has reached a pessimistic fixpoint. That
-  /// is, no further modifications to the lattice can occur.
-  virtual bool isAtFixpoint() const = 0;
+  const AbstractDenseState *get() const override = 0;
 };

 //===----------------------------------------------------------------------===//
@@ -78,26 +72,33 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
   /// Propagate the dense lattice before the execution of an operation to the
   /// lattice after its execution.
   virtual void visitOperationImpl(Operation *op,
-                                  const AbstractDenseLattice &before,
-                                  AbstractDenseLattice *after) = 0;
+                                  const AbstractDenseState &before,
+                                  AbstractDenseElement *after) = 0;

-  /// Get the dense lattice after the execution of the given program point.
-  virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0;
+  /// Get the dense element after the execution of the given program point.
+  virtual AbstractDenseElement *getLattice(ProgramPoint point) = 0;

   /// Get the dense lattice after the execution of the given program point and
   /// add it as a dependency to a program point.
-  const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
-                                            ProgramPoint point);
-
-  /// Mark the dense lattice as having reached its pessimistic fixpoint and
-  /// propagate an update if it changed.
-  void reset(AbstractDenseLattice *lattice) {
-    propagateIfChanged(lattice, lattice->reset());
+  const AbstractDenseState *getLatticeFor(ProgramPoint dependent,
+                                          ProgramPoint point);
+
+  /// Update the dense state held by `element` by invoking `updateFn` on it;
+  /// `updateFn` returns whether it changed the state.
+  void update(AbstractDenseElement *element,
+              function_ref<ChangeResult(AbstractDenseState *)> updateFn) {
+    element->update(this, [updateFn](AbstractState *state) {
+      return updateFn(static_cast<AbstractDenseState *>(state));
+    });
   }

-  /// Join a lattice with another and propagate an update if it changed.
-  void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) {
-    propagateIfChanged(lhs, lhs->join(rhs));
+  /// Mark the dense state held by `element` as having reached its pessimistic
+  /// fixpoint.
+  void markPessimisticFixpoint(AbstractDenseElement *element) {
+    element->update(this, [](AbstractState *state) {
+      return static_cast<AbstractDenseState *>(state)
+          ->markPessimisticFixpoint();
+    });
   }

 private:
@@ -116,7 +117,7 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
   /// parent operation itself.
   void visitRegionBranchOperation(ProgramPoint point,
                                   RegionBranchOpInterface branch,
-                                  AbstractDenseLattice *after);
+                                  AbstractDenseElement *after);
 };

 //===----------------------------------------------------------------------===//
@@ -128,29 +129,30 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
 /// transfer functions for operations.
 ///
-/// `StateT` is expected to be a subclass of `AbstractDenseLattice`.
-template <typename LatticeT>
+/// `StateT` is expected to be a subclass of `AbstractDenseState` that declares
+/// its element type as `StateT::ElementT`.
+template <typename StateT>
 class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis {
 public:
   using AbstractDenseDataFlowAnalysis::AbstractDenseDataFlowAnalysis;

   /// Visit an operation with the dense lattice before its execution. This
   /// function is expected to set the dense lattice after its execution.
-  virtual void visitOperation(Operation *op, const LatticeT &before,
-                              LatticeT *after) = 0;
+  virtual void visitOperation(Operation *op, const StateT &before,
+                              typename StateT::ElementT *after) = 0;

 protected:
   /// Get the dense lattice after this program point.
-  LatticeT *getLattice(ProgramPoint point) override {
-    return getOrCreate<LatticeT>(point);
+  typename StateT::ElementT *getLattice(ProgramPoint point) override {
+    return getOrCreate<StateT>(point);
   }

 private:
   /// Type-erased wrappers that convert the abstract dense lattice to a derived
   /// lattice and invoke the virtual hooks operating on the derived lattice.
-  void visitOperationImpl(Operation *op, const AbstractDenseLattice &before,
-                          AbstractDenseLattice *after) override {
-    visitOperation(op, static_cast<const LatticeT &>(before),
-                   static_cast<LatticeT *>(after));
+  void visitOperationImpl(Operation *op, const AbstractDenseState &before,
+                          AbstractDenseElement *after) override {
+    visitOperation(op, static_cast<const StateT &>(before),
+                   static_cast<typename StateT::ElementT *>(after));
   }
 };


diff  --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 3cd007ab478b..2473bf27f55f 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -16,6 +16,7 @@
 #define MLIR_ANALYSIS_DATAFLOW_INTEGERANGEANALYSIS_H

+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"

 namespace mlir {
@@ -27,7 +28,7 @@ class IntegerValueRange {
   /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
   /// range that is used to mark the value as unable to be analyzed further,
   /// where `t` is the type of `value`.
-  static IntegerValueRange getPessimisticValueState(Value value);
+  static IntegerValueRange getPessimisticValue(Value value);

   /// Create an integer value range lattice value.
   IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
@@ -45,6 +46,12 @@ class IntegerValueRange {
                                 const IntegerValueRange &rhs) {
     return lhs.value.rangeUnion(rhs.value);
   }
+
+  /// Meet two integer value ranges by taking their intersection.
+  static IntegerValueRange meet(const IntegerValueRange &lhs,
+                                const IntegerValueRange &rhs) {
+    return lhs.value.intersection(rhs.value);
+  }

   /// Print the integer value range.
   void print(raw_ostream &os) const { os << value; }
@@ -57,38 +64,51 @@ class IntegerValueRange {
-/// This lattice element represents the integer value range of an SSA value.
-/// When this lattice is updated, it automatically updates the constant value
-/// of the SSA value (if the range can be narrowed to one).
-class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
+/// This sparse state represents the integer value range of an SSA value.
+/// Updating it does not directly update the SSA value's constant value; that
+/// is provided separately by the `IntegerRangeToConstant` analysis.
+class IntegerValueRangeState : public OptimisticSparseState<IntegerValueRange> {
 public:
-  using Lattice::Lattice;
-
-  /// If the range can be narrowed to an integer constant, update the constant
-  /// value of the SSA value.
-  void onUpdate(DataFlowSolver *solver) const override;
+  using OptimisticSparseState::OptimisticSparseState;
+  using ElementT = SparseElement<IntegerValueRangeState, SingleStateElement>;
 };

 /// Integer range analysis determines the integer value range of SSA values
 /// using operations that define `InferIntRangeInterface` and also sets the
 /// range of iteration indices of loops with known bounds.
 class IntegerRangeAnalysis
-    : public SparseDataFlowAnalysis<IntegerValueRangeLattice> {
+    : public SparseDataFlowAnalysis<IntegerValueRangeState> {
 public:
   using SparseDataFlowAnalysis::SparseDataFlowAnalysis;

   /// Visit an operation. Invoke the transfer function on each operation that
   /// implements `InferIntRangeInterface`.
-  void visitOperation(Operation *op,
-                      ArrayRef<const IntegerValueRangeLattice *> operands,
-                      ArrayRef<IntegerValueRangeLattice *> results) override;
+  void
+  visitOperation(Operation *op,
+                 ArrayRef<const IntegerValueRangeState *> operands,
+                 ArrayRef<IntegerValueRangeState::ElementT *> results) override;

   /// Visit block arguments or operation results of an operation with region
   /// control-flow for which values are not defined by region control-flow. This
   /// function calls `InferIntRangeInterface` to provide values for block
   /// arguments or tries to reduce the range on loop induction variables with
   /// known bounds.
-  void
-  visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
-                               ArrayRef<IntegerValueRangeLattice *> argLattices,
-                               unsigned firstIndex) override;
+  void visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &successor,
+      ArrayRef<IntegerValueRangeState::ElementT *> argLattices,
+      unsigned firstIndex) override;
+};
+
+/// An analysis that provides the `ConstantValueState` of SSA values,
+/// presumably derived from their integer value ranges.
+class IntegerRangeToConstant : public DataFlowAnalysis {
+public:
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  LogicalResult initialize(Operation *top) override;
+  LogicalResult visit(ProgramPoint point) override;
+
+  /// Report that this analysis provides `ConstantValueState` for SSA values.
+  bool staticallyProvides(TypeID stateID, ProgramPoint point) const override {
+    return stateID == TypeID::get<ConstantValueState>() && point.is<Value>();
+  }
 };

 } // end namespace dataflow

diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 27af1bc46b0d..13c5abf4acdb 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -23,67 +23,87 @@ namespace mlir {
 namespace dataflow {
 
 //===----------------------------------------------------------------------===//
-// AbstractSparseLattice
+// AbstractSparseState
 //===----------------------------------------------------------------------===//
 
-/// This class represents an abstract lattice. A lattice contains information
-/// about an SSA value and is what's propagated across the IR by sparse
-/// data-flow analysis.
-class AbstractSparseLattice : public AnalysisState {
+class AbstractSparseState : public AbstractState {
 public:
-  /// Lattices can only be created for values.
-  AbstractSparseLattice(Value value) : AnalysisState(value) {}
+  /// Join the information contained in 'rhs' into this state. Returns
+  /// if the value of the state changed.
+  virtual ChangeResult join(const AbstractSparseState &rhs) = 0;
 
-  /// Join the information contained in 'rhs' into this lattice. Returns
-  /// if the value of the lattice changed.
-  virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
-
-  /// Returns true if the lattice element is at fixpoint and further calls to
-  /// `join` will not update the value of the element.
+  /// Returns true if the lattice state is at fixpoint and further calls to
+  /// `join` will not update the value of the state.
   virtual bool isAtFixpoint() const = 0;
 
-  /// Mark the lattice element as having reached a pessimistic fixpoint. This
-  /// means that the lattice may potentially have conflicting value states, and
-  /// only the most conservative value should be relied on.
+  /// Mark the lattice state as having reached a pessimistic fixpoint. This
+  /// means that the lattice may potentially have an overdefined or underdefined
+  /// value state, and only the most conservative value should be relied on.
   virtual ChangeResult markPessimisticFixpoint() = 0;
 
-  /// When the lattice gets updated, propagate an update to users of the value
-  /// using its use-def chain to subscribed analyses.
-  void onUpdate(DataFlowSolver *solver) const override;
+  /// Returns true if the value of this lattice hasn't yet been initialized.
+  virtual bool isUninitialized() const = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseElement
+//===----------------------------------------------------------------------===//
+
+class AbstractSparseElement : public AbstractElement {
+public:
+  /// Sparse elements can only be created on SSA values.
+  explicit AbstractSparseElement(DataFlowSolver &solver, Value value)
+      : AbstractElement(solver, value) {}
 
-  /// Subscribe an analysis to updates of the lattice. When the lattice changes,
-  /// subscribed analyses are re-invoked on all users of the value. This is
-  /// more efficient than relying on the dependency map.
-  void useDefSubscribe(DataFlowAnalysis *analysis) {
+  virtual void useDefSubscribe(DataFlowAnalysis *analysis) = 0;
+
+  virtual const AbstractSparseState *get() const override = 0;
+};
+
+/// This class represents a sparse analysis element. A sparse element is
+/// attached to an SSA value and can track its dependents through the value's
+/// use-def chain. This is useful for improving the performance of sparse
+/// analyses where users are always dependents of SSA value elements.
+template <typename StateT, template <typename, typename> class BaseT>
+class SparseElement : public BaseT<StateT, AbstractSparseElement> {
+public:
+  using BaseT<StateT, AbstractSparseElement>::BaseT;
+
+  /// When the sparse element gets updated, propagate an update to users of the
+  /// value using its use-def chain to subscribed analyses.
+  void onUpdate() override {
+    for (Operation *user : this->point.template get<Value>().getUsers())
+      for (DataFlowAnalysis *analysis : useDefSubscribers)
+        this->solver.enqueue({user, analysis});
+  }
+
+  /// Subscribe an analysis to updates of the sparse element. When the element
+  /// changes, subscribed analyses are re-invoked on all users of the value.
+  /// This is more efficient than relying on the dependency map.
+  void useDefSubscribe(DataFlowAnalysis *analysis) override {
     useDefSubscribers.insert(analysis);
   }
 
 private:
-  /// A set of analyses that should be updated when this lattice changes.
+  /// A set of analyses that should be updated when this element changes.
   SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
             SmallPtrSet<DataFlowAnalysis *, 4>>
       useDefSubscribers;
 };
 
 //===----------------------------------------------------------------------===//
-// Lattice
+// OptimisticSparseState
 //===----------------------------------------------------------------------===//
 
-/// This class represents a lattice holding a specific value of type `ValueT`.
-/// Lattice values (`ValueT`) are required to adhere to the following:
-///
-///   * static ValueT join(const ValueT &lhs, const ValueT &rhs);
-///     - This method conservatively joins the information held by `lhs`
-///       and `rhs` into a new value. This method is required to be monotonic.
-///   * bool operator==(const ValueT &rhs) const;
-///
+/// This class represents a sparse state that has an optimistic and known value.
+/// This class should be used when the overdefined/underdefined value state is
+/// not finitely representable.
 template <typename ValueT>
-class Lattice : public AbstractSparseLattice {
+class OptimisticSparseState : public AbstractSparseState {
 public:
-  /// Construct a lattice with a known value.
-  explicit Lattice(Value value)
-      : AbstractSparseLattice(value),
-        knownValue(ValueT::getPessimisticValueState(value)) {}
+  template <typename PointT>
+  explicit OptimisticSparseState(PointT point)
+      : knownValue(ValueT::getPessimisticValue(point)) {}
 
   /// Return the value held by this lattice. This requires that the value is
   /// initialized.
@@ -92,16 +112,11 @@ class Lattice : public AbstractSparseLattice {
     return *optimisticValue;
   }
   const ValueT &getValue() const {
-    return const_cast<Lattice<ValueT> *>(this)->getValue();
+    return const_cast<OptimisticSparseState<ValueT> *>(this)->getValue();
   }
 
   /// Returns true if the value of this lattice hasn't yet been initialized.
   bool isUninitialized() const override { return !optimisticValue.hasValue(); }
-  /// Force the initialization of the element by setting it to its pessimistic
-  /// fixpoint.
-  ChangeResult defaultInitialize() override {
-    return markPessimisticFixpoint();
-  }
 
   /// Returns true if the lattice has reached a fixpoint. A fixpoint is when
   /// the information optimistically assumed to be true is the same as the
@@ -110,9 +125,8 @@ class Lattice : public AbstractSparseLattice {
 
   /// Join the information contained in the 'rhs' lattice into this
   /// lattice. Returns if the state of the current lattice changed.
-  ChangeResult join(const AbstractSparseLattice &rhs) override {
-    const Lattice<ValueT> &rhsLattice =
-        static_cast<const Lattice<ValueT> &>(rhs);
+  ChangeResult join(const AbstractSparseState &rhs) override {
+    auto &rhsLattice = static_cast<const OptimisticSparseState<ValueT> &>(rhs);
 
     // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do.
     if (isAtFixpoint() || rhsLattice.isUninitialized())
@@ -122,6 +136,21 @@ class Lattice : public AbstractSparseLattice {
     return join(rhsLattice.getValue());
   }
 
+  ChangeResult meet(const OptimisticSparseState<ValueT> &rhs) {
+    if (isUninitialized())
+      return ChangeResult::NoChange;
+    if (rhs.isUninitialized()) {
+      optimisticValue.reset();
+      return ChangeResult::Change;
+    }
+    ValueT newValue = ValueT::meet(getValue(), rhs.getValue());
+    if (newValue == optimisticValue)
+      return ChangeResult::NoChange;
+
+    optimisticValue = newValue;
+    return ChangeResult::Change;
+  }
+
   /// Join the information contained in the 'rhs' value into this
   /// lattice. Returns if the state of the current lattice changed.
   ChangeResult join(const ValueT &rhs) {
@@ -159,16 +188,14 @@ class Lattice : public AbstractSparseLattice {
     return ChangeResult::Change;
   }
 
-  /// Print the lattice element.
   void print(raw_ostream &os) const override {
-    os << "[";
+    os << '[';
     knownValue.print(os);
-    os << ", ";
-    if (optimisticValue)
+    if (optimisticValue) {
+      os << ", ";
       optimisticValue->print(os);
-    else
-      os << "<NULL>";
-    os << "]";
+    }
+    os << ']';
   }
 
 private:
@@ -206,10 +233,10 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
 
   /// The operation transfer function. Given the operand lattices, this
   /// function is expected to set the result lattices.
-  virtual void
-  visitOperationImpl(Operation *op,
-                     ArrayRef<const AbstractSparseLattice *> operandLattices,
-                     ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
+  virtual void visitOperationImpl(
+      Operation *op,
+      ArrayRef<const AbstractSparseState *> operandLattices,
+      ArrayRef<AbstractSparseElement *> resultLattices) = 0;
 
   /// Given an operation with region control-flow, the lattices of the operands,
   /// and a region successor, compute the lattice values for block arguments
@@ -217,26 +244,26 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
   /// of loops).
   virtual void visitNonControlFlowArgumentsImpl(
       Operation *op, const RegionSuccessor &successor,
-      ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
+      ArrayRef<AbstractSparseElement *> argLattices, unsigned firstIndex) = 0;
 
   /// Get the lattice element of a value.
-  virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
+  virtual AbstractSparseElement *getLatticeElement(Value value) = 0;
 
   /// Get a read-only lattice element for a value and add it as a dependency to
   /// a program point.
-  const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
-                                                    Value value);
+  const AbstractSparseState *getLatticeElementFor(ProgramPoint point,
+                                                  Value value);
 
   /// Mark a lattice element as having reached its pessimistic fixpoint and
-  /// propgate an update if changed.
-  void markPessimisticFixpoint(AbstractSparseLattice *lattice);
+  /// propagate an update if changed.
+  void markPessimisticFixpoint(AbstractSparseElement *element);
 
   /// Mark the given lattice elements as having reached their pessimistic
   /// fixpoints and propagate an update if any changed.
-  void markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice *> lattices);
+  void markAllPessimisticFixpoint(ArrayRef<AbstractSparseElement *> elements);
 
-  /// Join the lattice element and propagate and update if it changed.
-  void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+  /// Join the lattice element and propagate an update if it changed.
+  void join(AbstractSparseElement *lhs, const AbstractSparseState &rhs);
 
 private:
   /// Recursively initialize the analysis on nested operations and blocks.
@@ -255,9 +285,10 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
   /// operation `branch`, which can either be the entry block of one of the
   /// regions or the parent operation itself, and set either the argument or
   /// parent result lattices.
-  void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
-                             Optional<unsigned> successorIndex,
-                             ArrayRef<AbstractSparseLattice *> lattices);
+  void
+  visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
+                        Optional<unsigned> successorIndex,
+                        ArrayRef<AbstractSparseElement *> elements);
 };
 
 //===----------------------------------------------------------------------===//
@@ -267,7 +298,7 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
 /// A sparse (forward) data-flow analysis for propagating SSA value lattices
 /// across the IR by implementing transfer functions for operations.
 ///
-/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
+/// `StateT` is expected to be a subclass of `AbstractSparseState`.
 template <typename StateT>
 class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
 public:
@@ -276,8 +307,9 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
 
   /// Visit an operation with the lattices of its operands. This function is
   /// expected to set the lattices of the operation's results.
-  virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
-                              ArrayRef<StateT *> results) = 0;
+  virtual void
+  visitOperation(Operation *op, ArrayRef<const StateT *> operands,
+                 ArrayRef<typename StateT::ElementT *> results) = 0;
 
   /// Given an operation with possible region control-flow, the lattices of the
   /// operands, and a region successor, compute the lattice values for block
@@ -285,18 +317,21 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
   /// the bounds of loops). By default, this method marks all such lattice
   /// elements as having reached a pessimistic fixpoint. `firstIndex` is the
   /// index of the first element of `argLattices` that is set by control-flow.
-  virtual void visitNonControlFlowArguments(Operation *op,
-                                            const RegionSuccessor &successor,
-                                            ArrayRef<StateT *> argLattices,
-                                            unsigned firstIndex) {
+  virtual void visitNonControlFlowArguments(
+      Operation *op, const RegionSuccessor &successor,
+      ArrayRef<typename StateT::ElementT *> argLattices, unsigned firstIndex) {
     markAllPessimisticFixpoint(argLattices.take_front(firstIndex));
     markAllPessimisticFixpoint(argLattices.drop_front(
         firstIndex + successor.getSuccessorInputs().size()));
   }
 
 protected:
+  bool staticallyProvides(TypeID stateID, ProgramPoint point) const override {
+    return stateID == TypeID::get<StateT>() && point.is<Value>();
+  }
+
   /// Get the lattice element for a value.
-  StateT *getLatticeElement(Value value) override {
+  typename StateT::ElementT *getLatticeElement(Value value) override {
     return getOrCreate<StateT>(value);
   }
 
@@ -309,32 +344,37 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
 
   /// Mark the lattice elements of a range of values as having reached their
   /// pessimistic fixpoint.
-  void markAllPessimisticFixpoint(ArrayRef<StateT *> lattices) {
+  void
+  markAllPessimisticFixpoint(ArrayRef<typename StateT::ElementT *> elements) {
     AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
-        {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
-         lattices.size()});
+        {reinterpret_cast<AbstractSparseElement *const *>(
+             elements.begin()),
+         elements.size()});
   }
 
 private:
   /// Type-erased wrappers that convert the abstract lattice operands to derived
   /// lattices and invoke the virtual hooks operating on the derived lattices.
   void visitOperationImpl(
-      Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
-      ArrayRef<AbstractSparseLattice *> resultLattices) override {
+      Operation *op,
+      ArrayRef<const AbstractSparseState *> operandLattices,
+      ArrayRef<AbstractSparseElement *> resultLattices) override {
     visitOperation(
         op,
         {reinterpret_cast<const StateT *const *>(operandLattices.begin()),
          operandLattices.size()},
-        {reinterpret_cast<StateT *const *>(resultLattices.begin()),
+        {reinterpret_cast<typename StateT::ElementT *const *>(
+             resultLattices.begin()),
          resultLattices.size()});
   }
   void visitNonControlFlowArgumentsImpl(
       Operation *op, const RegionSuccessor &successor,
-      ArrayRef<AbstractSparseLattice *> argLattices,
+      ArrayRef<AbstractSparseElement *> argLattices,
       unsigned firstIndex) override {
     visitNonControlFlowArguments(
         op, successor,
-        {reinterpret_cast<StateT *const *>(argLattices.begin()),
+        {reinterpret_cast<typename StateT::ElementT *const *>(
+             argLattices.begin()),
          argLattices.size()},
         firstIndex);
   }

diff  --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 2992e05f14dd..ab898a70a8bd 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -45,9 +45,6 @@ inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
   return lhs == ChangeResult::NoChange ? lhs : rhs;
 }
 
-/// Forward declare the analysis state class.
-class AnalysisState;
-
 //===----------------------------------------------------------------------===//
 // GenericProgramPoint
 //===----------------------------------------------------------------------===//
@@ -178,6 +175,8 @@ class DataFlowAnalysis;
 // DataFlowSolver
 //===----------------------------------------------------------------------===//
 
+class AbstractElement;
+
 /// The general data-flow analysis solver. This class is responsible for
 /// orchestrating child data-flow analyses, running the fixed-point iteration
 /// algorithm, managing analysis state and program point memory, and tracking
@@ -202,16 +201,23 @@ class DataFlowSolver {
   /// operation and run the analysis until fixpoint.
   LogicalResult initializeAndRun(Operation *top);
 
   /// Lookup an analysis state for the given program point. Returns null if one
   /// does not exist.
   template <typename StateT, typename PointT>
-  const StateT *lookupState(PointT point) const {
-    auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
-    if (it == analysisStates.end())
+  const StateT *lookup(PointT point) const {
+    using ElementT = typename StateT::ElementT;
+    auto it = elements.find({TypeID::get<StateT>(), ProgramPoint(point)});
+    if (it == elements.end())
       return nullptr;
-    return static_cast<const StateT *>(it->second.get());
+    return static_cast<const ElementT &>(*it->second).get();
   }
 
+  /// Get the element holding the state of kind `StateT` at the given program
+  /// point. If it does not exist, create it.
+  template <typename StateT, typename PointT>
+  typename StateT::ElementT *getOrCreate(PointT point);
+
+public:
   /// Get a uniqued program point instance. If one is not present, it is
   /// created with the provided arguments.
   template <typename PointT, typename... Args>
@@ -226,20 +228,12 @@ class DataFlowSolver {
   /// Push a work item onto the worklist.
   void enqueue(WorkItem item) { worklist.push(std::move(item)); }
 
-  /// Get the state associated with the given program point. If it does not
-  /// exist, create an uninitialized state.
-  template <typename StateT, typename PointT>
-  StateT *getOrCreateState(PointT point);
-
-  /// Propagate an update to an analysis state if it changed by pushing
-  /// dependent work items to the back of the queue.
-  void propagateIfChanged(AnalysisState *state, ChangeResult changed);
-
-  /// Add a dependency to an analysis state on a child analysis and program
-  /// point. If the state is updated, the child analysis must be invoked on the
-  /// given program point again.
-  void addDependency(AnalysisState *state, DataFlowAnalysis *analysis,
-                     ProgramPoint point);
+  /// Populate `staticProviders` with the child analyses that statically
+  /// provide states of kind `stateID` at `point`. A `MultiStateElement` keeps
+  /// a separate state for each of these analyses.
+  void getStaticProvidersFor(
+      TypeID stateID, ProgramPoint point,
+      SmallVectorImpl<DataFlowAnalysis *> &staticProviders) const;
 
 private:
   /// The solver's work queue. Work items can be inserted to the front of the
@@ -254,78 +245,171 @@ class DataFlowSolver {
   /// points.
   StorageUniquer uniquer;
 
-  /// A type-erased map of program points to associated analysis states for
-  /// first-class program points.
-  DenseMap<std::pair<ProgramPoint, TypeID>, std::unique_ptr<AnalysisState>>
-      analysisStates;
+  /// A type-erased map from (state kind, program point) pairs to the elements
+  /// holding the corresponding analysis states.
+  DenseMap<std::pair<TypeID, ProgramPoint>,
+           std::unique_ptr<AbstractElement>>
+      elements;
 
   /// Allow the base child analysis class to access the internals of the solver.
   friend class DataFlowAnalysis;
 };
 
 //===----------------------------------------------------------------------===//
-// AnalysisState
+// AbstractElement
 //===----------------------------------------------------------------------===//
 
-/// Base class for generic analysis states. Analysis states contain data-flow
-/// information that are attached to program points and which evolve as the
-/// analysis iterates.
-///
-/// This class places no restrictions on the semantics of analysis states beyond
-/// these requirements.
-///
-/// 1. Querying the state of a program point prior to visiting that point
-///    results in uninitialized state. Analyses must be aware of unintialized
-///    states.
-/// 2. Analysis states can reach fixpoints, where subsequent updates will never
-///    trigger a change in the state.
-/// 3. Analysis states that are uninitialized can be forcefully initialized to a
-///    default value.
-class AnalysisState {
+/// Base class for analysis states: the data-flow information attached to a
+/// program point. States are stored in and updated through an
+/// `AbstractElement`. Subclasses are required to implement `print`.
+class AbstractState {
 public:
-  virtual ~AnalysisState();
+  virtual ~AbstractState();
 
-  /// Create the analysis state at the given program point.
-  AnalysisState(ProgramPoint point) : point(point) {}
+  /// Print the contents of the analysis state.
+  virtual void print(raw_ostream &os) const = 0;
+};
 
-  /// Returns true if the analysis state is uninitialized.
-  virtual bool isUninitialized() const = 0;
+/// An element is the solver-owned container for the analysis state of one kind
+/// at one program point. It tracks the `(point, analysis)` work items that
+/// depend on it; subclasses call `propagateUpdate` whenever the state they
+/// expose changes so that those dependents are revisited.
+///
+/// Subclasses are required to implement `get` and `update`.
+class AbstractElement {
+public:
+  virtual ~AbstractElement();
 
-  /// Force an uninitialized analysis state to initialize itself with a default
-  /// value.
-  virtual ChangeResult defaultInitialize() = 0;
+  /// Create an element for `point`, owned by `solver`.
+  explicit AbstractElement(DataFlowSolver &solver, ProgramPoint point)
+      : solver(solver), point(point) {}
 
-  /// Print the contents of the analysis state.
-  virtual void print(raw_ostream &os) const = 0;
+  /// Add a dependency on this element from `analysis` at `point`. If the
+  /// element's state is updated, `analysis` must be invoked on `point` again.
+  void addDependency(DataFlowAnalysis *analysis, ProgramPoint point);
+
+  /// Return the current state of this element, as seen by readers.
+  virtual const AbstractState *get() const = 0;
+  /// Apply `updateFn` to the state this element keeps for `provider`.
+  /// `updateFn` returns whether it changed that state; implementations call
+  /// `propagateUpdate` when the state returned by `get` changes.
+  virtual void update(DataFlowAnalysis *provider,
+                      function_ref<ChangeResult(AbstractState *)> updateFn) = 0;
 
 protected:
-  /// This function is called by the solver when the analysis state is updated
-  /// to optionally enqueue more work items. For example, if a state tracks
-  /// dependents through the IR (e.g. use-def chains), this function can be
-  /// implemented to push those dependents on the worklist.
-  virtual void onUpdate(DataFlowSolver *solver) const {}
-
-  /// The dependency relations originating from this analysis state. An entry
-  /// `state -> (analysis, point)` is created when `analysis` queries `state`
-  /// when updating `point`.
-  ///
-  /// When this state is updated, all dependent child analysis invocations are
-  /// pushed to the back of the queue. Use a `SetVector` to keep the analysis
-  /// deterministic.
-  ///
-  /// Store the dependents on the analysis state for efficiency.
-  SetVector<DataFlowSolver::WorkItem> dependents;
+  /// Propagate a change of this element's state to its dependents.
+  void propagateUpdate();
+
+  /// Hook for subclasses to schedule additional work when the element is
+  /// updated, e.g. for dependents tracked through the IR (use-def chains).
+  virtual void onUpdate() {}
 
-  /// The program point to which the state belongs.
+  /// The solver that owns this element.
+  DataFlowSolver &solver;
+  /// The program point to which the element belongs.
   ProgramPoint point;
 
+private:
+  /// The work items that depend on this element. Use a `SetVector` to keep
+  /// the analysis deterministic.
+  SetVector<DataFlowSolver::WorkItem, SmallVector<DataFlowSolver::WorkItem>,
+            llvm::SmallDenseSet<DataFlowSolver::WorkItem>>
+      dependents;
+
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  /// When compiling with debugging, keep a name for the analysis state.
+  /// When compiling with debugging, keep a name for the element.
   StringRef debugName;
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
-  /// Allow the framework to access the dependents.
-  friend class DataFlowSolver;
+  /// Allow the solver to access the element's internals (e.g. `debugName`).
+  friend class ::mlir::DataFlowSolver;
+};
+
+/// An element holding a single `StateT` that is updated in place by every
+/// analysis; the `provider` passed to `update` is ignored. Dependents are
+/// notified whenever an update reports a change.
+template <typename StateT, typename BaseT = AbstractElement>
+class SingleStateElement : public BaseT {
+public:
+  template <typename PointT>
+  explicit SingleStateElement(DataFlowSolver &solver, PointT point)
+      : BaseT(solver, point), state(point) {}
+
+  const StateT *get() const override { return &state; }
+
+  void update(DataFlowAnalysis *provider,
+              function_ref<ChangeResult(AbstractState *)> updateFn) override {
+    if (updateFn(&state) == ChangeResult::Change)
+      BaseT::propagateUpdate();
+  }
+  /// Convenience overload taking a callback on the concrete `StateT`.
+  void update(DataFlowAnalysis *provider,
+              function_ref<ChangeResult(StateT *)> updateFn) {
+    return update(provider, function_ref<ChangeResult(AbstractState *)>(
+                                [updateFn](AbstractState *state) {
+                                  return updateFn(static_cast<StateT *>(state));
+                                }));
+  }
+
+private:
+  /// The state shared by all providers.
+  StateT state;
+};
+
+/// An element that keeps a separate `StateT` for each analysis that statically
+/// provides this state kind at the element's program point (as reported by
+/// `DataFlowSolver::getStaticProvidersFor` on construction), plus a combined
+/// state returned by `get`. An update from a static provider is applied to
+/// that provider's state, after which the meet of all provider states is
+/// joined into the combined state. An update from any other analysis is
+/// applied directly to the combined state. Dependents are notified when the
+/// combined state changes.
+///
+/// StateT is required to implement `join` and `meet`.
+template <typename StateT, typename BaseT = AbstractElement>
+class MultiStateElement : public BaseT {
+public:
+  template <typename PointT>
+  explicit MultiStateElement(DataFlowSolver &solver, PointT point)
+      : BaseT(solver, point), state(point) {
+    SmallVector<DataFlowAnalysis *, 2> staticProviders;
+    solver.getStaticProvidersFor(TypeID::get<StateT>(), point, staticProviders);
+    for (DataFlowAnalysis *staticProvider : staticProviders)
+      states.try_emplace(staticProvider, StateT(point));
+  }
+
+  const StateT *get() const override { return &state; }
+
+  void update(DataFlowAnalysis *provider,
+              function_ref<ChangeResult(AbstractState *)> updateFn) override {
+    auto it = states.find(provider);
+    if (it == states.end()) {
+      if (updateFn(&state) == ChangeResult::Change)
+        BaseT::propagateUpdate();
+      return;
+    }
+    if (updateFn(&it->second) == ChangeResult::NoChange)
+      return;
+    StateT newState(it->second);
+    for (auto &entry : states)
+      (void)newState.meet(entry.second);
+    if (state.join(newState) == ChangeResult::Change)
+      BaseT::propagateUpdate();
+  }
+  /// Convenience overload taking a callback on the concrete `StateT`.
+  void update(DataFlowAnalysis *provider,
+              function_ref<ChangeResult(StateT *)> updateFn) {
+    return update(provider, function_ref<ChangeResult(AbstractState *)>(
+                                [updateFn](AbstractState *state) {
+                                  return updateFn(static_cast<StateT *>(state));
+                                }));
+  }
+
+private:
+  /// The combined state returned by `get`.
+  StateT state;
+  /// The per-provider states, keyed by static provider.
+  llvm::SmallDenseMap<DataFlowAnalysis *, StateT, 2> states;
 };
 
 //===----------------------------------------------------------------------===//
@@ -385,12 +427,13 @@ class DataFlowAnalysis {
   virtual LogicalResult visit(ProgramPoint point) = 0;
 
 protected:
-  /// Create a dependency between the given analysis state and program point
-  /// on this analysis.
-  void addDependency(AnalysisState *state, ProgramPoint point);
-
-  /// Propagate an update to a state if it changed.
-  void propagateIfChanged(AnalysisState *state, ChangeResult changed);
+  /// Returns true if this analysis *statically* provides values for the given
+  /// state kind for the given program point. This means the analysis will
+  /// always provide values for this state regardless of the state of the
+  /// analysis.
+  virtual bool staticallyProvides(TypeID stateID, ProgramPoint point) const {
+    return false;
+  }
 
   /// Register a custom program point class.
   template <typename PointT>
@@ -404,12 +447,22 @@ class DataFlowAnalysis {
     return solver.getProgramPoint<PointT>(std::forward<Args>(args)...);
   }
 
+  /// Apply `updateFn` to the state of kind `StateT` at `point` on behalf of
+  /// this analysis, creating the element first if needed. The element notifies
+  /// its dependents when the update changes the state it exposes.
+  template <typename StateT, typename PointT>
+  void update(PointT point, function_ref<ChangeResult(StateT *)> updateFn) {
+    auto *element = getOrCreate<StateT>(point);
+    element->update(this, updateFn);
+  }
+
-  /// Get the analysis state assiocated with the program point. The returned
-  /// state is expected to be "write-only", and any updates need to be
-  /// propagated by `propagateIfChanged`.
+  /// Get the element holding the analysis state of kind `StateT` associated
+  /// with the program point, creating it if it does not exist. Changes to the
+  /// state should go through the element's `update` method (or
+  /// `DataFlowAnalysis::update`) so that dependents are notified.
   template <typename StateT, typename PointT>
-  StateT *getOrCreate(PointT point) {
-    return solver.getOrCreateState<StateT>(point);
+  typename StateT::ElementT *getOrCreate(PointT point) {
+    return solver.getOrCreate<StateT>(point);
   }
 
   /// Get a read-only analysis state for the given point and create a dependency
@@ -417,14 +466,15 @@ class DataFlowAnalysis {
   /// re-invoked on the dependent.
   template <typename StateT, typename PointT>
   const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) {
-    StateT *state = getOrCreate<StateT>(point);
-    addDependency(state, dependent);
-    return state;
+    auto *element = getOrCreate<StateT>(point);
+    element->addDependency(this, dependent);
+    return element->get();
   }
 
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// When compiling with debugging, keep a name for the analyis.
   StringRef debugName;
+  friend class AbstractElement;
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
 private:
@@ -445,19 +495,23 @@ AnalysisT *DataFlowSolver::load(Args &&...args) {
 }
 
 template <typename StateT, typename PointT>
-StateT *DataFlowSolver::getOrCreateState(PointT point) {
-  std::unique_ptr<AnalysisState> &state =
-      analysisStates[{ProgramPoint(point), TypeID::get<StateT>()}];
-  if (!state) {
-    state = std::unique_ptr<StateT>(new StateT(point));
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-    state->debugName = llvm::getTypeName<StateT>();
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
+typename StateT::ElementT *DataFlowSolver::getOrCreate(PointT point) {
+  using ElementT = typename StateT::ElementT;
+  static_assert(std::is_base_of<AbstractElement, ElementT>::value,
+                "expected an abstract element");
+  std::unique_ptr<AbstractElement> &element =
+      elements[{TypeID::get<StateT>(), ProgramPoint(point)}];
+  if (!element) {
+    element = std::unique_ptr<ElementT>(new ElementT(*this, point));
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+    // `debugName` is only declared when ABI-breaking checks are enabled.
+    element->debugName = llvm::getTypeName<StateT>();
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
   }
-  return static_cast<StateT *>(state.get());
+  return static_cast<ElementT *>(element.get());
 }
 
-inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) {
+inline raw_ostream &operator<<(raw_ostream &os, const AbstractState &state) {
   state.print(os);
   return os;
 }

diff  --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 386237e47b5e..39affbdcfb34 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -30,8 +30,8 @@ void ConstantValue::print(raw_ostream &os) const {
 //===----------------------------------------------------------------------===//
 
 void SparseConstantPropagation::visitOperation(
-    Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
-    ArrayRef<Lattice<ConstantValue> *> results) {
+    Operation *op, ArrayRef<const ConstantValueState *> operands,
+    ArrayRef<ConstantValueState::ElementT *> results) {
   LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
 
   // Don't try to simulate the results of a region operation as we can't
@@ -39,12 +39,12 @@ void SparseConstantPropagation::visitOperation(
   // folds as the desire here is for simulated execution, and not general
   // folding.
   if (op->getNumRegions())
-    return;
+    return markAllPessimisticFixpoint(results);
 
   SmallVector<Attribute, 8> constantOperands;
   constantOperands.reserve(op->getNumOperands());
-  for (auto *operandLattice : operands)
-    constantOperands.push_back(operandLattice->getValue().getConstantValue());
+  for (auto *operandState : operands)
+    constantOperands.push_back(operandState->getValue().getConstantValue());
 
   // Save the original operands and attributes just in case the operation
   // folds in-place. The constant passed in may not correspond to the real
@@ -56,10 +56,8 @@ void SparseConstantPropagation::visitOperation(
   // fails or was an in-place fold, mark the results as overdefined.
   SmallVector<OpFoldResult, 8> foldResults;
   foldResults.reserve(op->getNumResults());
-  if (failed(op->fold(constantOperands, foldResults))) {
-    markAllPessimisticFixpoint(results);
-    return;
-  }
+  if (failed(op->fold(constantOperands, foldResults)))
+    return markAllPessimisticFixpoint(results);
 
   // If the folding was in-place, mark the results as overdefined and reset
   // the operation. We don't allow in-place folds as the desire here is for
@@ -67,25 +65,26 @@ void SparseConstantPropagation::visitOperation(
   if (foldResults.empty()) {
     op->setOperands(originalOperands);
     op->setAttrs(originalAttrs);
-    return;
+    return markAllPessimisticFixpoint(results);
   }
 
   // Merge the fold results into the lattice for this operation.
   assert(foldResults.size() == op->getNumResults() && "invalid result size");
   for (const auto it : llvm::zip(results, foldResults)) {
-    Lattice<ConstantValue> *lattice = std::get<0>(it);
+    ConstantValueState::ElementT *element = std::get<0>(it);
 
     // Merge in the result of the fold, either a constant or a value.
     OpFoldResult foldResult = std::get<1>(it);
     if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
       LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
-      propagateIfChanged(lattice,
-                         lattice->join(ConstantValue(attr, op->getDialect())));
+      element->update(this, [attr, op](ConstantValueState *state) {
+        return state->join(ConstantValue(attr, op->getDialect()));
+      });
     } else {
       LLVM_DEBUG(llvm::dbgs()
                  << "Folded to value: " << foldResult.get<Value>() << "\n");
       AbstractSparseDataFlowAnalysis::join(
-          lattice, *getLatticeElement(foldResult.get<Value>()));
+          element, *getLatticeElement(foldResult.get<Value>())->get());
     }
   }
 }

diff  --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 5e5215144c39..6ff6fb4c3735 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -27,24 +27,6 @@ void Executable::print(raw_ostream &os) const {
   os << (live ? "live" : "dead");
 }
 
-void Executable::onUpdate(DataFlowSolver *solver) const {
-  if (auto *block = point.dyn_cast<Block *>()) {
-    // Re-invoke the analyses on the block itself.
-    for (DataFlowAnalysis *analysis : subscribers)
-      solver->enqueue({block, analysis});
-    // Re-invoke the analyses on all operations in the block.
-    for (DataFlowAnalysis *analysis : subscribers)
-      for (Operation &op : *block)
-        solver->enqueue({&op, analysis});
-  } else if (auto *programPoint = point.dyn_cast<GenericProgramPoint *>()) {
-    // Re-invoke the analysis on the successor block.
-    if (auto *edge = dyn_cast<CFGEdge>(programPoint)) {
-      for (DataFlowAnalysis *analysis : subscribers)
-        solver->enqueue({edge->getTo(), analysis});
-    }
-  }
-}
-
 //===----------------------------------------------------------------------===//
 // PredecessorState
 //===----------------------------------------------------------------------===//
@@ -104,8 +86,8 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
   for (Region &region : top->getRegions()) {
     if (region.empty())
       continue;
-    auto *state = getOrCreate<Executable>(&region.front());
-    propagateIfChanged(state, state->setToLive());
+    update<Executable>(&region.front(),
+                       [](Executable *state) { return state->setToLive(); });
   }
 
   // Mark as overdefined the predecessors of symbol callables with potentially
@@ -132,8 +114,9 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
       // Public symbol callables or those for which we can't see all uses have
       // potentially unknown callsites.
       if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
-        auto *state = getOrCreate<PredecessorState>(callable);
-        propagateIfChanged(state, state->setHasUnknownPredecessors());
+        update<PredecessorState>(callable, [](PredecessorState *state) {
+          return state->setHasUnknownPredecessors();
+        });
       }
       foundSymbolCallable = true;
     }
@@ -149,8 +132,9 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
       // If we couldn't gather the symbol uses, conservatively assume that
       // we can't track information for any nested symbols.
       return top->walk([&](CallableOpInterface callable) {
-        auto *state = getOrCreate<PredecessorState>(callable);
-        propagateIfChanged(state, state->setHasUnknownPredecessors());
+        update<PredecessorState>(callable, [](PredecessorState *state) {
+          return state->setHasUnknownPredecessors();
+        });
       });
     }
 
@@ -160,12 +144,13 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
       // If a callable symbol has a non-call use, then we can't be guaranteed to
       // know all callsites.
       Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
-      auto *state = getOrCreate<PredecessorState>(symbol);
-      propagateIfChanged(state, state->setHasUnknownPredecessors());
+      update<PredecessorState>(symbol, [](PredecessorState *state) {
+        return state->setHasUnknownPredecessors();
+      });
     }
   };
   SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
                                 walkFn);
 }
 
 /// Returns true if the operation terminates a block. It is insufficient to
@@ -198,18 +182,17 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
 }
 
 void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
-  auto *state = getOrCreate<Executable>(to);
-  propagateIfChanged(state, state->setToLive());
-  auto *edgeState = getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
-  propagateIfChanged(edgeState, edgeState->setToLive());
+  update<Executable>(to, [](Executable *state) { return state->setToLive(); });
+  update<Executable>(getProgramPoint<CFGEdge>(from, to),
+                     [](Executable *state) { return state->setToLive(); });
 }
 
 void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
   for (Region &region : op->getRegions()) {
     if (region.empty())
       continue;
-    auto *state = getOrCreate<Executable>(&region.front());
-    propagateIfChanged(state, state->setToLive());
+    update<Executable>(&region.front(),
+                       [](Executable *state) { return state->setToLive(); });
   }
 }
 
@@ -221,7 +204,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
     return emitError(point.getLoc(), "unknown program point kind");
 
   // If the parent block is not executable, there is nothing to do.
-  if (!getOrCreate<Executable>(op->getBlock())->isLive())
+  if (!getOrCreate<Executable>(op->getBlock())->get()->isLive())
     return success();
 
   // We have a live call op. Add this as a live predecessor of the callee.
@@ -296,25 +279,27 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
   if (isa_and_nonnull<SymbolOpInterface>(callableOp) &&
       !isExternalCallable(callableOp)) {
     // Add the live callsite.
-    auto *callsites = getOrCreate<PredecessorState>(callableOp);
-    propagateIfChanged(callsites, callsites->join(call));
+    update<PredecessorState>(callableOp, [call](PredecessorState *state) {
+      return state->join(call);
+    });
   } else {
     // Mark this call op's predecessors as overdefined.
-    auto *predecessors = getOrCreate<PredecessorState>(call);
-    propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
+    update<PredecessorState>(call, [](PredecessorState *state) {
+      return state->setHasUnknownPredecessors();
+    });
   }
 }
 
 /// Get the constant values of the operands of an operation. If any of the
 /// constant value lattices are uninitialized, return none to indicate the
 /// analysis should bail out.
-static Optional<SmallVector<Attribute>> getOperandValuesImpl(
-    Operation *op,
-    function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
+static Optional<SmallVector<Attribute>>
+getOperandValuesImpl(Operation *op,
+                     function_ref<const ConstantValueState *(Value)> getState) {
   SmallVector<Attribute> operands;
   operands.reserve(op->getNumOperands());
   for (Value operand : op->getOperands()) {
-    const Lattice<ConstantValue> *cv = getLattice(operand);
+    const ConstantValueState *cv = getState(operand);
     // If any of the operands' values are uninitialized, bail out.
     if (cv->isUninitialized())
       return {};
@@ -325,11 +310,12 @@ static Optional<SmallVector<Attribute>> getOperandValuesImpl(
 
 Optional<SmallVector<Attribute>>
 DeadCodeAnalysis::getOperandValues(Operation *op) {
-  return getOperandValuesImpl(op, [&](Value value) {
-    auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
-    lattice->useDefSubscribe(this);
-    return lattice;
-  });
+  return getOperandValuesImpl(
+      op, [&](Value value) -> const ConstantValueState * {
+        auto *element = getOrCreate<ConstantValueState>(value);
+        element->useDefSubscribe(this);
+        return element->get();
+      });
 }
 
 void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
@@ -362,13 +348,13 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
                              ? &successor.getSuccessor()->front()
                              : ProgramPoint(branch);
     // Mark the entry block as executable.
-    auto *state = getOrCreate<Executable>(point);
-    propagateIfChanged(state, state->setToLive());
+    update<Executable>(point,
+                       [](Executable *state) { return state->setToLive(); });
     // Add the parent op as a predecessor.
-    auto *predecessors = getOrCreate<PredecessorState>(point);
-    propagateIfChanged(
-        predecessors,
-        predecessors->join(branch, successor.getSuccessorInputs()));
+    update<PredecessorState>(
+        point, [branch, &successor](PredecessorState *state) {
+          return state->join(branch, successor.getSuccessorInputs());
+        });
   }
 }
 
@@ -385,17 +371,20 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
   // Mark successor region entry blocks as executable and add this op to the
   // list of predecessors.
   for (const RegionSuccessor &successor : successors) {
-    PredecessorState *predecessors;
     if (Region *region = successor.getSuccessor()) {
-      auto *state = getOrCreate<Executable>(&region->front());
-      propagateIfChanged(state, state->setToLive());
-      predecessors = getOrCreate<PredecessorState>(&region->front());
+      update<Executable>(&region->front(),
+                         [](Executable *state) { return state->setToLive(); });
+      update<PredecessorState>(
+          &region->front(), [op, &successor](PredecessorState *state) {
+            return state->join(op, successor.getSuccessorInputs());
+          });
     } else {
       // Add this terminator as a predecessor to the parent op.
-      predecessors = getOrCreate<PredecessorState>(branch);
+      update<PredecessorState>(
+          branch, [op, &successor](PredecessorState *state) {
+            return state->join(op, successor.getSuccessorInputs());
+          });
     }
-    propagateIfChanged(predecessors,
-                       predecessors->join(op, successor.getSuccessorInputs()));
   }
 }
 
@@ -412,12 +401,14 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
     assert(isa<CallOpInterface>(predecessor));
     auto *predecessors = getOrCreate<PredecessorState>(predecessor);
     if (canResolve) {
-      propagateIfChanged(predecessors, predecessors->join(op));
+      predecessors->update(
+          this, [op](PredecessorState *state) { return state->join(op); });
     } else {
       // If the terminator is not a return-like, then conservatively assume we
       // can't resolve the predecessor.
-      propagateIfChanged(predecessors,
-                         predecessors->setHasUnknownPredecessors());
+      predecessors->update(this, [](PredecessorState *state) {
+        return state->setHasUnknownPredecessors();
+      });
     }
   }
 }

diff  --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 95c4aadc7592..957880f6c999 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -45,8 +45,8 @@ void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
     return;
 
   // Get the dense lattice to update.
-  AbstractDenseLattice *after = getLattice(op);
-  if (after->isAtFixpoint())
+  AbstractDenseElement *after = getLattice(op);
+  if (after->get()->isAtFixpoint())
     return;
 
   // If this op implements region control-flow, then control-flow dictates its
@@ -61,14 +61,17 @@ void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
     // If not all return sites are known, then conservatively assume we can't
     // reason about the data-flow.
     if (!predecessors->allPredecessorsKnown())
-      return reset(after);
-    for (Operation *predecessor : predecessors->getKnownPredecessors())
-      join(after, *getLatticeFor(op, predecessor));
-    return;
+      return markPessimisticFixpoint(after);
+    return update(after, [this, predecessors, op](AbstractDenseState *state) {
+      ChangeResult result = ChangeResult::NoChange;
+      for (Operation *predecessor : predecessors->getKnownPredecessors())
+        result |= state->join(*getLatticeFor(op, predecessor));
+      return result;
+    });
   }
 
   // Get the dense state before the execution of the op.
-  const AbstractDenseLattice *before;
+  const AbstractDenseState *before;
   if (Operation *prev = op->getPrevNode())
     before = getLatticeFor(op, prev);
   else
@@ -87,8 +90,8 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
     return;
 
   // Get the dense lattice to update.
-  AbstractDenseLattice *after = getLattice(block);
-  if (after->isAtFixpoint())
+  AbstractDenseElement *after = getLattice(block);
+  if (after->get()->isAtFixpoint())
     return;
 
   // The dense lattices of entry blocks are set by region control-flow or the
@@ -101,15 +104,17 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
       // If not all callsites are known, conservatively mark all lattices as
       // having reached their pessimistic fixpoints.
       if (!callsites->allPredecessorsKnown())
-        return reset(after);
-      for (Operation *callsite : callsites->getKnownPredecessors()) {
-        // Get the dense lattice before the callsite.
-        if (Operation *prev = callsite->getPrevNode())
-          join(after, *getLatticeFor(block, prev));
-        else
-          join(after, *getLatticeFor(block, callsite->getBlock()));
-      }
-      return;
+        return markPessimisticFixpoint(after);
+      return update(after, [this, callsites, block](AbstractDenseState *state) {
+        ChangeResult result = ChangeResult::NoChange;
+        for (Operation *callsite : callsites->getKnownPredecessors()) {
+          if (Operation *prev = callsite->getPrevNode())
+            result |= state->join(*getLatticeFor(block, prev));
+          else
+            result |= state->join(*getLatticeFor(block, callsite->getBlock()));
+        }
+        return result;
+      });
     }
 
     // Check if we can reason about the control-flow.
@@ -117,53 +122,62 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
       return visitRegionBranchOperation(block, branch, after);
 
     // Otherwise, we can't reason about the data-flow.
-    return reset(after);
+    return markPessimisticFixpoint(after);
   }
 
   // Join the state with the state after the block's predecessors.
-  for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
-       it != e; ++it) {
-    // Skip control edges that aren't executable.
-    Block *predecessor = *it;
-    if (!getOrCreateFor<Executable>(
-             block, getProgramPoint<CFGEdge>(predecessor, block))
-             ->isLive())
-      continue;
-
-    // Merge in the state from the predecessor's terminator.
-    join(after, *getLatticeFor(block, predecessor->getTerminator()));
-  }
+  update(after, [this, block](AbstractDenseState *state) {
+    ChangeResult result = ChangeResult::NoChange;
+    for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
+         it != e; ++it) {
+      // Skip control edges that aren't executable.
+      Block *predecessor = *it;
+      if (!getOrCreateFor<Executable>(
+               block, getProgramPoint<CFGEdge>(predecessor, block))
+               ->isLive())
+        continue;
+
+      // Merge in the state from the predecessor's terminator.
+      result |=
+          state->join(*getLatticeFor(block, predecessor->getTerminator()));
+    }
+    return result;
+  });
 }
 
 void AbstractDenseDataFlowAnalysis::visitRegionBranchOperation(
     ProgramPoint point, RegionBranchOpInterface branch,
-    AbstractDenseLattice *after) {
+    AbstractDenseElement *after) {
   // Get the terminator predecessors.
   const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
   assert(predecessors->allPredecessorsKnown() &&
          "unexpected unresolved region successors");
 
-  for (Operation *op : predecessors->getKnownPredecessors()) {
-    const AbstractDenseLattice *before;
-    // If the predecessor is the parent, get the state before the parent.
-    if (op == branch) {
-      if (Operation *prev = op->getPrevNode())
-        before = getLatticeFor(point, prev);
-      else
-        before = getLatticeFor(point, op->getBlock());
-
-      // Otherwise, get the state after the terminator.
-    } else {
-      before = getLatticeFor(point, op);
+  update(after, [&](AbstractDenseState *state) {
+    ChangeResult result = ChangeResult::NoChange;
+    for (Operation *op : predecessors->getKnownPredecessors()) {
+      const AbstractDenseState *before;
+      // If the predecessor is the parent, get the state before the parent.
+      if (op == branch) {
+        if (Operation *prev = op->getPrevNode())
+          before = getLatticeFor(point, prev);
+        else
+          before = getLatticeFor(point, op->getBlock());
+
+        // Otherwise, get the state after the terminator.
+      } else {
+        before = getLatticeFor(point, op);
+      }
+      result |= state->join(*before);
     }
-    join(after, *before);
-  }
+    return result;
+  });
 }
 
-const AbstractDenseLattice *
+const AbstractDenseState *
 AbstractDenseDataFlowAnalysis::getLatticeFor(ProgramPoint dependent,
                                              ProgramPoint point) {
-  AbstractDenseLattice *state = getLattice(point);
-  addDependency(state, dependent);
-  return state;
+  AbstractDenseElement *element = getLattice(point);
+  element->addDependency(this, dependent);
+  return element->get();
 }

diff  --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index e983341faf02..434f7d44f751 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -23,7 +23,7 @@
 using namespace mlir;
 using namespace mlir::dataflow;
 
-IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
+IntegerValueRange IntegerValueRange::getPessimisticValue(Value value) {
   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
   APInt umin = APInt::getMinValue(width);
   APInt umax = APInt::getMaxValue(width);
@@ -32,30 +32,9 @@ IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
   return {{umin, umax, smin, smax}};
 }
 
-void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
-  Lattice::onUpdate(solver);
-
-  // If the integer range can be narrowed to a constant, update the constant
-  // value of the SSA value.
-  Optional<APInt> constant = getValue().getValue().getConstantValue();
-  auto value = point.get<Value>();
-  auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
-  if (!constant)
-    return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
-
-  Dialect *dialect;
-  if (auto *parent = value.getDefiningOp())
-    dialect = parent->getDialect();
-  else
-    dialect = value.getParentBlock()->getParentOp()->getDialect();
-  solver->propagateIfChanged(
-      cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
-                                 dialect)));
-}
-
 void IntegerRangeAnalysis::visitOperation(
-    Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
-    ArrayRef<IntegerValueRangeLattice *> results) {
+    Operation *op, ArrayRef<const IntegerValueRangeState *> operands,
+    ArrayRef<IntegerValueRangeState::ElementT *> results) {
   // Ignore non-integer outputs - return early if the op has no scalar
   // integer results
   bool hasIntegerResult = false;
@@ -63,8 +42,9 @@ void IntegerRangeAnalysis::visitOperation(
     if (std::get<1>(it).getType().isIntOrIndex()) {
       hasIntegerResult = true;
     } else {
-      propagateIfChanged(std::get<0>(it),
-                         std::get<0>(it)->markPessimisticFixpoint());
+      std::get<0>(it)->update(this, [](IntegerValueRangeState *state) {
+        return state->markPessimisticFixpoint();
+      });
     }
   }
   if (!hasIntegerResult)
@@ -76,7 +56,7 @@ void IntegerRangeAnalysis::visitOperation(
 
   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
   SmallVector<ConstantIntRanges> argRanges(
-      llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
+      llvm::map_range(operands, [](const IntegerValueRangeState *val) {
         return val->getValue().getValue();
       }));
 
@@ -87,26 +67,28 @@ void IntegerRangeAnalysis::visitOperation(
     assert(llvm::find(op->getResults(), result) != op->result_end());
 
     LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
-    IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
-    Optional<IntegerValueRange> oldRange;
-    if (!lattice->isUninitialized())
-      oldRange = lattice->getValue();
-
-    ChangeResult changed = lattice->join(attrs);
-
-    // Catch loop results with loop variant bounds and conservatively make
-    // them [-inf, inf] so we don't circle around infinitely often (because
-    // the dataflow analysis in MLIR doesn't attempt to work out trip counts
-    // and often can't).
-    bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
-      return op->hasTrait<OpTrait::IsTerminator>();
-    });
-    if (isYieldedResult && oldRange.hasValue() &&
-        !(lattice->getValue() == *oldRange)) {
-      LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
-      changed |= lattice->markPessimisticFixpoint();
-    }
-    propagateIfChanged(lattice, changed);
+    results[result.getResultNumber()]->update(
+        this, [&](IntegerValueRangeState *state) {
+          Optional<IntegerValueRange> oldRange;
+          if (!state->isUninitialized())
+            oldRange = state->getValue();
+
+          ChangeResult changed = state->join(attrs);
+
+          // Catch loop results with loop variant bounds and conservatively make
+          // them [-inf, inf] so we don't circle around infinitely often
+          // (because the dataflow analysis in MLIR doesn't attempt to work out
+          // trip counts and often can't).
+          bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+            return op->hasTrait<OpTrait::IsTerminator>();
+          });
+          if (isYieldedResult && oldRange.hasValue() &&
+              !(state->getValue() == *oldRange)) {
+            LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+            changed |= state->markPessimisticFixpoint();
+          }
+          return changed;
+        });
   };
 
   inferrable.inferResultRanges(argRanges, joinCallback);
@@ -114,7 +96,8 @@ void IntegerRangeAnalysis::visitOperation(
 
 void IntegerRangeAnalysis::visitNonControlFlowArguments(
     Operation *op, const RegionSuccessor &successor,
-    ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
+    ArrayRef<IntegerValueRangeState::ElementT *> argLattices,
+    unsigned firstIndex) {
   if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
     LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
     SmallVector<ConstantIntRanges> argRanges(
@@ -131,25 +114,28 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
         return;
 
       LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
-      IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
-      Optional<IntegerValueRange> oldRange;
-      if (!lattice->isUninitialized())
-        oldRange = lattice->getValue();
-
-      ChangeResult changed = lattice->join(attrs);
-
-      // Catch loop results with loop variant bounds and conservatively make
-      // them [-inf, inf] so we don't circle around infinitely often (because
-      // the dataflow analysis in MLIR doesn't attempt to work out trip counts
-      // and often can't).
-      bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
-        return op->hasTrait<OpTrait::IsTerminator>();
-      });
-      if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
-        LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
-        changed |= lattice->markPessimisticFixpoint();
-      }
-      propagateIfChanged(lattice, changed);
+      argLattices[arg.getArgNumber()]->update(
+          this, [&](IntegerValueRangeState *state) {
+            Optional<IntegerValueRange> oldRange;
+            if (!state->isUninitialized())
+              oldRange = state->getValue();
+
+            ChangeResult changed = state->join(attrs);
+
+            // Catch loop results with loop variant bounds and conservatively
+            // make them [-inf, inf] so we don't circle around infinitely often
+            // (because the dataflow analysis in MLIR doesn't attempt to work
+            // out trip counts and often can't).
+            bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
+              return op->hasTrait<OpTrait::IsTerminator>();
+            });
+            if (isYieldedValue && oldRange &&
+                !(state->getValue() == *oldRange)) {
+              LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+              changed |= state->markPessimisticFixpoint();
+            }
+            return changed;
+          });
     };
 
     inferrable.inferResultRanges(argRanges, joinCallback);
@@ -168,11 +154,9 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
                 loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
           return bound.getValue();
       } else if (auto value = loopBound->dyn_cast<Value>()) {
-        const IntegerValueRangeLattice *lattice =
-            getLatticeElementFor(op, value);
-        if (lattice != nullptr)
-          return getUpper ? lattice->getValue().getValue().smax()
-                          : lattice->getValue().getValue().smin();
+        const IntegerValueRangeState *state = getLatticeElementFor(op, value);
+        return getUpper ? state->getValue().getValue().smax()
+                        : state->getValue().getValue().smin();
       }
     }
     // Given the results of getConstant{Lower,Upper}Bound()
@@ -192,13 +176,10 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
     Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
     Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
     Optional<OpFoldResult> step = loop.getSingleStep();
-    APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
-                                     /*getUpper=*/false);
-    APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
-                                     /*getUpper=*/true);
+    APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), false);
+    APInt max = getLoopBoundFromFold(upperBound, iv->getType(), true);
     // Assume positivity for uniscoverable steps by way of getUpper = true.
-    APInt stepVal =
-        getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
+    APInt stepVal = getLoopBoundFromFold(step, iv->getType(), true);
 
     if (stepVal.isNegative()) {
       std::swap(min, max);
@@ -208,12 +189,53 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
       max -= 1;
     }
 
-    IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
+    auto *ivEntry = getLatticeElement(*iv);
     auto ivRange = ConstantIntRanges::fromSigned(min, max);
-    propagateIfChanged(ivEntry, ivEntry->join(ivRange));
-    return;
+    return ivEntry->update(this, [&ivRange](IntegerValueRangeState *state) {
+      return state->join(ivRange);
+    });
   }
 
   return SparseDataFlowAnalysis::visitNonControlFlowArguments(
       op, successor, argLattices, firstIndex);
 }
+
+LogicalResult IntegerRangeToConstant::initialize(Operation *top) {
+  auto visitValues = [this](ValueRange values) {
+    for (Value value : values)
+      (void)visit(value);
+  };
+  top->walk([&](Operation *op) {
+    visitValues(op->getResults());
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        visitValues(block.getArguments());
+  });
+  return success();
+}
+
+LogicalResult IntegerRangeToConstant::visit(ProgramPoint point) {
+  auto value = point.get<Value>();
+  auto *rangeState = getOrCreateFor<IntegerValueRangeState>(value, value);
+  if (rangeState->isUninitialized())
+    return success();
+
+  update<ConstantValueState>(value, [&](ConstantValueState *state) {
+    const ConstantIntRanges &range = rangeState->getValue().getValue();
+    // Try to narrow to a constant.
+    Optional<APInt> constant = range.getConstantValue();
+    if (!constant)
+      return state->markPessimisticFixpoint();
+
+    // Find a dialect to materialize the constant.
+    Dialect *dialect;
+    if (Operation *op = value.getDefiningOp())
+      dialect = op->getDialect();
+    else
+      dialect = value.getParentRegion()->getParentOp()->getDialect();
+
+    Attribute attr = IntegerAttr::get(value.getType(), *constant);
+    return state->join(ConstantValue(attr, dialect));
+  });
+  return success();
+}

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 4caa5eea326a..4a08efc7b6f1 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -12,17 +12,6 @@
 using namespace mlir;
 using namespace mlir::dataflow;
 
-//===----------------------------------------------------------------------===//
-// AbstractSparseLattice
-//===----------------------------------------------------------------------===//
-
-void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
-  // Push all users of the value to the queue.
-  for (Operation *user : point.get<Value>().getUsers())
-    for (DataFlowAnalysis *analysis : useDefSubscribers)
-      solver->enqueue({user, analysis});
-}
-
 //===----------------------------------------------------------------------===//
 // AbstractSparseDataFlowAnalysis
 //===----------------------------------------------------------------------===//
@@ -80,28 +69,26 @@ void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
     return;
 
   // If the containing block is not executable, bail out.
-  if (!getOrCreate<Executable>(op->getBlock())->isLive())
+  if (!getOrCreate<Executable>(op->getBlock())->get()->isLive())
     return;
 
   // Get the result lattices.
-  SmallVector<AbstractSparseLattice *> resultLattices;
-  resultLattices.reserve(op->getNumResults());
+  SmallVector<AbstractSparseElement *> resultElements;
+  resultElements.reserve(op->getNumResults());
   // Track whether all results have reached their fixpoint.
   bool allAtFixpoint = true;
   for (Value result : op->getResults()) {
-    AbstractSparseLattice *resultLattice = getLatticeElement(result);
-    allAtFixpoint &= resultLattice->isAtFixpoint();
-    resultLattices.push_back(resultLattice);
+    AbstractSparseElement *resultElement = getLatticeElement(result);
+    allAtFixpoint &= resultElement->get()->isAtFixpoint();
+    resultElements.push_back(resultElement);
   }
   // If all result lattices have reached a fixpoint, there is nothing to do.
   if (allAtFixpoint)
     return;
 
   // The results of a region branch operation are determined by control-flow.
-  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
-    return visitRegionSuccessors({branch}, branch,
-                                 /*successorIndex=*/llvm::None, resultLattices);
-  }
+  if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
+    return visitRegionSuccessors({branch}, branch, llvm::None, resultElements);
 
   // The results of a call operation are determined by the callgraph.
   if (auto call = dyn_cast<CallOpInterface>(op)) {
@@ -109,27 +96,27 @@ void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
     // If not all return sites are known, then conservatively assume we can't
     // reason about the data-flow.
     if (!predecessors->allPredecessorsKnown())
-      return markAllPessimisticFixpoint(resultLattices);
+      return markAllPessimisticFixpoint(resultElements);
     for (Operation *predecessor : predecessors->getKnownPredecessors())
-      for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
+      for (auto it : llvm::zip(predecessor->getOperands(), resultElements))
         join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
     return;
   }
 
   // Grab the lattice elements of the operands.
-  SmallVector<const AbstractSparseLattice *> operandLattices;
-  operandLattices.reserve(op->getNumOperands());
+  SmallVector<const AbstractSparseState *> operandStates;
+  operandStates.reserve(op->getNumOperands());
   for (Value operand : op->getOperands()) {
-    AbstractSparseLattice *operandLattice = getLatticeElement(operand);
-    operandLattice->useDefSubscribe(this);
+    AbstractSparseElement *operandElement = getLatticeElement(operand);
+    operandElement->useDefSubscribe(this);
     // If any of the operand states are not initialized, bail out.
-    if (operandLattice->isUninitialized())
+    if (operandElement->get()->isUninitialized())
       return;
-    operandLattices.push_back(operandLattice);
+    operandStates.push_back(operandElement->get());
   }
 
   // Invoke the operation transfer function.
-  visitOperationImpl(op, operandLattices, resultLattices);
+  visitOperationImpl(op, operandStates, resultElements);
 }
 
 void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
@@ -138,17 +125,17 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
     return;
 
   // If the block is not executable, bail out.
-  if (!getOrCreate<Executable>(block)->isLive())
+  if (!getOrCreate<Executable>(block)->get()->isLive())
     return;
 
   // Get the argument lattices.
-  SmallVector<AbstractSparseLattice *> argLattices;
-  argLattices.reserve(block->getNumArguments());
+  SmallVector<AbstractSparseElement *> argElements;
+  argElements.reserve(block->getNumArguments());
   bool allAtFixpoint = true;
   for (BlockArgument argument : block->getArguments()) {
-    AbstractSparseLattice *argLattice = getLatticeElement(argument);
-    allAtFixpoint &= argLattice->isAtFixpoint();
-    argLattices.push_back(argLattice);
+    AbstractSparseElement *argElement = getLatticeElement(argument);
+    allAtFixpoint &= argElement->get()->isAtFixpoint();
+    argElements.push_back(argElement);
   }
   // If all argument lattices have reached their fixpoints, then there is
   // nothing to do.
@@ -165,10 +152,10 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
       // If not all callsites are known, conservatively mark all lattices as
       // having reached their pessimistic fixpoints.
       if (!callsites->allPredecessorsKnown())
-        return markAllPessimisticFixpoint(argLattices);
+        return markAllPessimisticFixpoint(argElements);
       for (Operation *callsite : callsites->getKnownPredecessors()) {
         auto call = cast<CallOpInterface>(callsite);
-        for (auto it : llvm::zip(call.getArgOperands(), argLattices))
+        for (auto it : llvm::zip(call.getArgOperands(), argElements))
           join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
       }
       return;
@@ -177,13 +164,13 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
     // Check if the lattices can be determined from region control flow.
     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
       return visitRegionSuccessors(
-          block, branch, block->getParent()->getRegionNumber(), argLattices);
+          block, branch, block->getParent()->getRegionNumber(), argElements);
     }
 
     // Otherwise, we can't reason about the data-flow.
     return visitNonControlFlowArgumentsImpl(block->getParentOp(),
                                             RegionSuccessor(block->getParent()),
-                                            argLattices, /*firstIndex=*/0);
+                                            argElements, /*firstIndex=*/0);
   }
 
   // Iterate over the predecessors of the non-entry block.
@@ -196,7 +183,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
     auto *edgeExecutable =
         getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
     edgeExecutable->blockContentSubscribe(this);
-    if (!edgeExecutable->isLive())
+    if (!edgeExecutable->get()->isLive())
       continue;
 
     // Check if we can reason about the data-flow from the predecessor.
@@ -204,7 +191,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
             dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
       SuccessorOperands operands =
           branch.getSuccessorOperands(it.getSuccessorIndex());
-      for (auto &it : llvm::enumerate(argLattices)) {
+      for (auto &it : llvm::enumerate(argElements)) {
         if (Value operand = operands[it.index()]) {
           join(it.value(), *getLatticeElementFor(block, operand));
         } else {
@@ -214,7 +201,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
         }
       }
     } else {
-      return markAllPessimisticFixpoint(argLattices);
+      return markAllPessimisticFixpoint(argElements);
     }
   }
 }
@@ -222,7 +209,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
 void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
     ProgramPoint point, RegionBranchOpInterface branch,
     Optional<unsigned> successorIndex,
-    ArrayRef<AbstractSparseLattice *> lattices) {
+    ArrayRef<AbstractSparseElement *> elements) {
   const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
   assert(predecessors->allPredecessorsKnown() &&
          "unexpected unresolved region successors");
@@ -242,7 +229,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
 
     if (!operands) {
       // We can't reason about the data-flow.
-      return markAllPessimisticFixpoint(lattices);
+      return markAllPessimisticFixpoint(elements);
     }
 
     ValueRange inputs = predecessors->getSuccessorInputs(op);
@@ -250,7 +237,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
            "expected the same number of successor inputs as operands");
 
     unsigned firstIndex = 0;
-    if (inputs.size() != lattices.size()) {
+    if (inputs.size() != elements.size()) {
       if (auto *op = point.dyn_cast<Operation *>()) {
         if (!inputs.empty())
           firstIndex = inputs.front().cast<OpResult>().getResultNumber();
@@ -258,7 +245,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
             branch,
             RegionSuccessor(
                 branch->getResults().slice(firstIndex, inputs.size())),
-            lattices, firstIndex);
+            elements, firstIndex);
       } else {
         if (!inputs.empty())
           firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
@@ -267,36 +254,42 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
             branch,
             RegionSuccessor(region, region->getArguments().slice(
                                         firstIndex, inputs.size())),
-            lattices, firstIndex);
+            elements, firstIndex);
       }
     }
 
-    for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
+    for (auto it : llvm::zip(*operands, elements.drop_front(firstIndex)))
       join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
   }
 }
 
-const AbstractSparseLattice *
+const AbstractSparseState *
 AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
                                                      Value value) {
-  AbstractSparseLattice *state = getLatticeElement(value);
-  addDependency(state, point);
-  return state;
+  AbstractSparseElement *element = getLatticeElement(value);
+  element->addDependency(this, point);
+  return element->get();
 }
 
 void AbstractSparseDataFlowAnalysis::markPessimisticFixpoint(
-    AbstractSparseLattice *lattice) {
-  propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
+    AbstractSparseElement *element) {
+  element->update(this, [](AbstractState *state) {
+    return static_cast<AbstractSparseState *>(state)
+        ->markPessimisticFixpoint();
+  });
 }
 
 void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
-    ArrayRef<AbstractSparseLattice *> lattices) {
-  for (AbstractSparseLattice *lattice : lattices) {
-    markPessimisticFixpoint(lattice);
-  }
+    ArrayRef<AbstractSparseElement *> elements) {
+  for (AbstractSparseElement *element : elements)
+    markPessimisticFixpoint(element);
 }
 
-void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
-                                          const AbstractSparseLattice &rhs) {
-  propagateIfChanged(lhs, lhs->join(rhs));
+void AbstractSparseDataFlowAnalysis::join(
+    AbstractSparseElement *lhs,
+    const AbstractSparseState &rhs) {
+  lhs->update(this, [&rhs](AbstractState *lhsState) {
+    return static_cast<AbstractSparseState *>(lhsState)->join(rhs);
+  });
 }
+

diff  --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 18d9ba1bd5d6..0eca6263dc4e 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -24,12 +24,6 @@ using namespace mlir;
 
 GenericProgramPoint::~GenericProgramPoint() = default;
 
-//===----------------------------------------------------------------------===//
-// AnalysisState
-//===----------------------------------------------------------------------===//
-
-AnalysisState::~AnalysisState() = default;
-
 //===----------------------------------------------------------------------===//
 // ProgramPoint
 //===----------------------------------------------------------------------===//
@@ -58,6 +52,34 @@ Location ProgramPoint::getLoc() const {
   return get<Block *>()->getParent()->getLoc();
 }
 
+//===----------------------------------------------------------------------===//
+// AbstractState and AbstractElement
+//===----------------------------------------------------------------------===//
+
+AbstractState::~AbstractState() = default;
+AbstractElement::~AbstractElement() = default;
+
+void AbstractElement::addDependency(DataFlowAnalysis *analysis,
+                                            ProgramPoint point) {
+  auto inserted = dependents.insert({point, analysis});
+  (void)inserted;
+  DATAFLOW_DEBUG({
+    if (inserted) {
+      llvm::dbgs() << "Adding dependency from " << debugName << " of "
+                   << this->point << " to " << analysis->debugName << " on "
+                   << point << "\n";
+    }
+  });
+}
+
+void AbstractElement::propagateUpdate() {
+  DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << debugName << " of "
+                              << point << "\nValue: " << *get() << "\n");
+  for (auto &item : dependents)
+    solver.enqueue(item);
+  onUpdate();
+}
+
 //===----------------------------------------------------------------------===//
 // DataFlowSolver
 //===----------------------------------------------------------------------===//
@@ -94,30 +116,12 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
   return success();
 }
 
-void DataFlowSolver::propagateIfChanged(AnalysisState *state,
-                                        ChangeResult changed) {
-  if (changed == ChangeResult::Change) {
-    DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
-                                << " of " << state->point << "\n"
-                                << "Value: " << *state << "\n");
-    for (const WorkItem &item : state->dependents)
-      enqueue(item);
-    state->onUpdate(this);
-  }
-}
-
-void DataFlowSolver::addDependency(AnalysisState *state,
-                                   DataFlowAnalysis *analysis,
-                                   ProgramPoint point) {
-  auto inserted = state->dependents.insert({point, analysis});
-  (void)inserted;
-  DATAFLOW_DEBUG({
-    if (inserted) {
-      llvm::dbgs() << "Creating dependency between " << state->debugName
-                   << " of " << state->point << "\nand " << analysis->debugName
-                   << " on " << point << "\n";
-    }
-  });
+void DataFlowSolver::getStaticProvidersFor(
+    TypeID stateID, ProgramPoint point,
+    SmallVectorImpl<DataFlowAnalysis *> &staticProviders) const {
+  for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses))
+    if (analysis.staticallyProvides(stateID, point))
+      staticProviders.push_back(&analysis);
 }
 
 //===----------------------------------------------------------------------===//
@@ -127,12 +131,3 @@ void DataFlowSolver::addDependency(AnalysisState *state,
 DataFlowAnalysis::~DataFlowAnalysis() = default;
 
 DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
-
-void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
-  solver.addDependency(state, this, point);
-}
-
-void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,
-                                          ChangeResult changed) {
-  solver.propagateIfChanged(state, changed);
-}

diff  --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
index 49d0ac70a604..9ce9cabdcbff 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -23,8 +23,8 @@ using namespace mlir::dataflow;
 /// bound on its value (if it is treated as signed) and that bound is
 /// non-negative.
 static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
-  auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
-  if (!result)
+  auto *result = solver.lookup<IntegerValueRangeState>(v);
+  if (!result || result->isUninitialized())
     return failure();
   const ConstantIntRanges &range = result->getValue().getValue();
   return success(range.smin().isNonNegative());
@@ -113,6 +113,7 @@ struct ArithmeticUnsignedWhenEquivalentPass
     DataFlowSolver solver;
     solver.load<DeadCodeAnalysis>();
     solver.load<IntegerRangeAnalysis>();
+    solver.load<IntegerRangeToConstant>();
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
 

diff  --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 27cda9be50ab..b0e447b9b14d 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -37,7 +37,7 @@ using namespace mlir::dataflow;
 static LogicalResult replaceWithConstant(DataFlowSolver &solver,
                                          OpBuilder &builder,
                                          OperationFolder &folder, Value value) {
-  auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
+  auto *lattice = solver.lookup<ConstantValueState>(value);
   if (!lattice || lattice->isUninitialized())
     return failure();
   const ConstantValue &latticeValue = lattice->getValue();

diff  --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index 27e994cce3b6..d12e57e1f509 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -29,7 +29,7 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
         os << "  ";
         block.printAsOperand(os);
         os << " = ";
-        auto *live = solver.lookupState<Executable>(&block);
+        auto *live = solver.lookup<Executable>(&block);
         if (live)
           os << *live;
         else
@@ -39,7 +39,7 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
           os << "   from ";
           pred->printAsOperand(os);
           os << " = ";
-          auto *live = solver.lookupState<Executable>(
+          auto *live = solver.lookup<Executable>(
               solver.getProgramPoint<CFGEdge>(pred, &block));
           if (live)
             os << *live;
@@ -49,12 +49,12 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
         }
       }
       if (!region.empty()) {
-        auto *preds = solver.lookupState<PredecessorState>(&region.front());
+        auto *preds = solver.lookup<PredecessorState>(&region.front());
         if (preds)
           os << "region_preds: " << *preds << "\n";
       }
     }
-    auto *preds = solver.lookupState<PredecessorState>(op);
+    auto *preds = solver.lookup<PredecessorState>(op);
     if (preds)
       os << "op_preds: " << *preds << "\n";
   });
@@ -79,9 +79,10 @@ struct ConstantAnalysis : public DataFlowAnalysis {
     Operation *op = point.get<Operation *>();
     Attribute value;
     if (matchPattern(op, m_Constant(&value))) {
-      auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
-      propagateIfChanged(
-          constant, constant->join(ConstantValue(value, op->getDialect())));
+      auto *constant = getOrCreate<ConstantValueState>(op->getResult(0));
+      constant->update(this, [value, op](ConstantValueState *state) {
+        return state->join(ConstantValue(value, op->getDialect()));
+      });
       return success();
     }
     markAllPessimisticFixpoint(op->getResults());
@@ -94,9 +95,10 @@ struct ConstantAnalysis : public DataFlowAnalysis {
   /// pessimistic fixpoint.
   void markAllPessimisticFixpoint(ValueRange values) {
     for (Value value : values) {
-      auto *constantValue = getOrCreate<Lattice<ConstantValue>>(value);
-      propagateIfChanged(constantValue,
-                         constantValue->markPessimisticFixpoint());
+      auto *constant = getOrCreate<ConstantValueState>(value);
+      constant->update(this, [](ConstantValueState *state) {
+        return state->markPessimisticFixpoint();
+      });
     }
   }
 };

diff  --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
index 98ccb6f94747..3dc63146a591 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
@@ -20,9 +20,7 @@ namespace {
 class UnderlyingValue {
 public:
   /// The pessimistic underlying value of a value is itself.
-  static UnderlyingValue getPessimisticValueState(Value value) {
-    return {value};
-  }
+  static UnderlyingValue getPessimisticValue(Value value) { return {value}; }
 
   /// Create an underlying value state with a known underlying value.
   UnderlyingValue(Value underlyingValue = {})
@@ -51,21 +49,21 @@ class UnderlyingValue {
 
 /// This lattice represents, for a given memory resource, the potential last
 /// operations that modified the resource.
-class LastModification : public AbstractDenseLattice {
+class LastModification : public AbstractDenseState {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
 
-  using AbstractDenseLattice::AbstractDenseLattice;
+  using ElementT =
+      SingleStateElement<LastModification, AbstractDenseElement>;
+
+  explicit LastModification(ProgramPoint point) {}
 
   /// The lattice is always initialized.
   bool isUninitialized() const override { return false; }
 
-  /// Initialize the lattice. Does nothing.
-  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
-
   /// Mark the lattice as having reached its pessimistic fixpoint. That is, the
   /// last modifications of all memory resources are unknown.
-  ChangeResult reset() override {
+  ChangeResult markPessimisticFixpoint() override {
     if (lastMods.empty())
       return ChangeResult::NoChange;
     lastMods.clear();
@@ -76,7 +74,7 @@ class LastModification : public AbstractDenseLattice {
   bool isAtFixpoint() const override { return false; }
 
   /// Join the last modifications.
-  ChangeResult join(const AbstractDenseLattice &lattice) override {
+  ChangeResult join(const AbstractDenseState &lattice) override {
     const auto &rhs = static_cast<const LastModification &>(lattice);
     ChangeResult result = ChangeResult::NoChange;
     for (const auto &mod : rhs.lastMods) {
@@ -135,13 +133,16 @@ class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
   /// its reaching definitions is set to empty. If the operation writes to a
   /// resource, then its reaching definition is set to the written value.
   void visitOperation(Operation *op, const LastModification &before,
-                      LastModification *after) override;
+                      LastModification::ElementT *after) override;
 };
 
 /// Define the lattice class explicitly to provide a type ID.
-struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
-  using Lattice::Lattice;
+struct UnderlyingValueState : public OptimisticSparseState<UnderlyingValue> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueState)
+
+  using OptimisticSparseState::OptimisticSparseState;
+  using ElementT =
+      SparseElement<UnderlyingValueState, SingleStateElement>;
 };
 
 /// An analysis that uses forwarding of values along control-flow and callgraph
@@ -149,14 +150,14 @@ struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
 /// analysis exists so that the test analysis and pass can test the behaviour of
 /// the dense data-flow analysis on the callgraph.
 class UnderlyingValueAnalysis
-    : public SparseDataFlowAnalysis<UnderlyingValueLattice> {
+    : public SparseDataFlowAnalysis<UnderlyingValueState> {
 public:
   using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
 
   /// The underlying value of the results of an operation are not known.
-  void visitOperation(Operation *op,
-                      ArrayRef<const UnderlyingValueLattice *> operands,
-                      ArrayRef<UnderlyingValueLattice *> results) override {
+  void
+  visitOperation(Operation *op, ArrayRef<const UnderlyingValueState *> operands,
+                 ArrayRef<UnderlyingValueState::ElementT *> results) override {
     markAllPessimisticFixpoint(results);
   }
 };
@@ -165,8 +166,8 @@ class UnderlyingValueAnalysis
 /// Look for the most underlying value of a value.
 static Value getMostUnderlyingValue(
     Value value,
-    function_ref<const UnderlyingValueLattice *(Value)> getUnderlyingValueFn) {
-  const UnderlyingValueLattice *underlying;
+    function_ref<const UnderlyingValueState *(Value)> getUnderlyingValueFn) {
+  const UnderlyingValueState *underlying;
   do {
     underlying = getUnderlyingValueFn(value);
     if (!underlying || underlying->isUninitialized())
@@ -181,38 +182,42 @@ static Value getMostUnderlyingValue(
 
 void LastModifiedAnalysis::visitOperation(Operation *op,
                                           const LastModification &before,
-                                          LastModification *after) {
+                                          LastModification::ElementT *after) {
   auto memory = dyn_cast<MemoryEffectOpInterface>(op);
   // If we can't reason about the memory effects, then conservatively assume we
   // can't deduce anything about the last modifications.
   if (!memory)
-    return reset(after);
+    return markPessimisticFixpoint(after);
 
   SmallVector<MemoryEffects::EffectInstance> effects;
   memory.getEffects(effects);
 
-  ChangeResult result = after->join(before);
-  for (const auto &effect : effects) {
-    Value value = effect.getValue();
+  after->update(this, [&](LastModification *state) {
+    ChangeResult result = state->join(before);
+    for (const auto &effect : effects) {
+      Value value = effect.getValue();
 
-    // If we see an effect on anything other than a value, assume we can't
-    // deduce anything about the last modifications.
-    if (!value)
-      return reset(after);
+      // If we see an effect on anything other than a value, assume we can't
+      // deduce anything about the last modifications.
+      if (!value) {
+        result |= state->markPessimisticFixpoint();
+        break;
+      }
 
-    value = getMostUnderlyingValue(value, [&](Value value) {
-      return getOrCreateFor<UnderlyingValueLattice>(op, value);
-    });
-    if (!value)
-      return;
+      value = getMostUnderlyingValue(value, [&](Value value) {
+        return getOrCreateFor<UnderlyingValueState>(op, value);
+      });
+      if (!value)
+        return ChangeResult::NoChange;
 
-    // Nothing to do for reads.
-    if (isa<MemoryEffects::Read>(effect.getEffect()))
-      continue;
+      // Nothing to do for reads.
+      if (isa<MemoryEffects::Read>(effect.getEffect()))
+        continue;
 
-    result |= after->set(value, op);
-  }
-  propagateIfChanged(after, result);
+      result |= state->set(value, op);
+    }
+    return result;
+  });
 }
 
 namespace {
@@ -240,13 +245,12 @@ struct TestLastModifiedPass
       if (!tag)
         return;
       os << "test_tag: " << tag.getValue() << ":\n";
-      const LastModification *lastMods =
-          solver.lookupState<LastModification>(op);
+      const auto *lastMods = solver.lookup<LastModification>(op);
       assert(lastMods && "expected a dense lattice");
       for (auto &it : llvm::enumerate(op->getOperands())) {
         os << " operand #" << it.index() << "\n";
         Value value = getMostUnderlyingValue(it.value(), [&](Value value) {
-          return solver.lookupState<UnderlyingValueLattice>(value);
+          return solver.lookup<UnderlyingValueState>(value);
         });
         assert(value && "expected an underlying value");
         if (Optional<ArrayRef<Operation *>> lastMod =

diff  --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
index 329be3c5446f..2fd41b25b765 100644
--- a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
+++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp
@@ -14,17 +14,15 @@ using namespace mlir;
 
 namespace {
 /// This analysis state represents an integer that is XOR'd with other states.
-class FooState : public AnalysisState {
+class FooState : public AbstractState {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState)
 
-  using AnalysisState::AnalysisState;
+  using ElementT = SingleStateElement<FooState>;
 
-  /// Default-initialize the state to zero.
-  ChangeResult defaultInitialize() override { return join(0); }
+  explicit FooState(ProgramPoint point) {}
 
-  /// Returns true if the state is uninitialized.
-  bool isUninitialized() const override { return !state; }
+  bool isUninitialized() const { return !state; }
 
   /// Print the integer value or "none" if uninitialized.
   void print(raw_ostream &os) const override {
@@ -99,7 +97,8 @@ LogicalResult FooAnalysis::initialize(Operation *top) {
     return top->emitError("expected a single region top-level op");
 
   // Initialize the top-level state.
-  getOrCreate<FooState>(&top->getRegion(0).front())->join(0);
+  update<FooState>(&top->getRegion(0).front(),
+                   [](FooState *state) { return state->join(0); });
 
   // Visit all nested blocks and operations.
   for (Block &block : top->getRegion(0)) {
@@ -130,35 +129,37 @@ void FooAnalysis::visitBlock(Block *block) {
     // This is the initial state. Let the framework default-initialize it.
     return;
   }
-  FooState *state = getOrCreate<FooState>(block);
-  ChangeResult result = ChangeResult::NoChange;
-  for (Block *pred : block->getPredecessors()) {
-    // Join the state at the terminators of all predecessors.
-    const FooState *predState =
-        getOrCreateFor<FooState>(block, pred->getTerminator());
-    result |= state->join(*predState);
-  }
-  propagateIfChanged(state, result);
+  update<FooState>(block, [&](FooState *state) {
+    ChangeResult result = ChangeResult::NoChange;
+    for (Block *pred : block->getPredecessors()) {
+      // Join the state at the terminators of all predecessors.
+      const FooState *predState =
+          getOrCreateFor<FooState>(block, pred->getTerminator());
+      result |= state->join(*predState);
+    }
+    return result;
+  });
 }
 
 void FooAnalysis::visitOperation(Operation *op) {
-  FooState *state = getOrCreate<FooState>(op);
-  ChangeResult result = ChangeResult::NoChange;
-
-  // Copy the state across the operation.
-  const FooState *prevState;
-  if (Operation *prev = op->getPrevNode())
-    prevState = getOrCreateFor<FooState>(op, prev);
-  else
-    prevState = getOrCreateFor<FooState>(op, op->getBlock());
-  result |= state->set(*prevState);
-
-  // Modify the state with the attribute, if specified.
-  if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
-    uint64_t value = attr.getUInt();
-    result |= state->join(value);
-  }
-  propagateIfChanged(state, result);
+  update<FooState>(op, [&](FooState *state) {
+    ChangeResult result = ChangeResult::NoChange;
+
+    // Copy the state across the operation.
+    const FooState *prevState;
+    if (Operation *prev = op->getPrevNode())
+      prevState = getOrCreateFor<FooState>(op, prev);
+    else
+      prevState = getOrCreateFor<FooState>(op, op->getBlock());
+    result |= state->set(*prevState);
+
+    // Modify the state with the attribute, if specified.
+    if (auto attr = op->getAttrOfType<IntegerAttr>("foo")) {
+      uint64_t value = attr.getUInt();
+      result |= state->join(value);
+    }
+    return result;
+  });
 }
 
 void TestFooAnalysisPass::runOnOperation() {
@@ -175,7 +176,7 @@ void TestFooAnalysisPass::runOnOperation() {
     auto tag = op->getAttrOfType<StringAttr>("tag");
     if (!tag)
       return;
-    const FooState *state = solver.lookupState<FooState>(op);
+    const FooState *state = solver.lookup<FooState>(op);
     assert(state && !state->isUninitialized());
     os << tag.getValue() << " -> " << state->getValue() << "\n";
   });

diff  --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
index 05e908aae311..f5bfff8e6e1a 100644
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
@@ -23,23 +23,15 @@ using namespace mlir::dataflow;
 /// Patterned after SCCP
 static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
                                          OperationFolder &folder, Value value) {
-  auto *maybeInferredRange =
-      solver.lookupState<IntegerValueRangeLattice>(value);
-  if (!maybeInferredRange || maybeInferredRange->isUninitialized())
-    return failure();
-  const ConstantIntRanges &inferredRange =
-      maybeInferredRange->getValue().getValue();
-  Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
-  if (!maybeConstValue.hasValue())
+  auto *constantState = solver.lookup<ConstantValueState>(value);
+  if (!constantState || constantState->isUninitialized() ||
+      !constantState->getValue().getConstantValue())
     return failure();
 
-  Operation *maybeDefiningOp = value.getDefiningOp();
-  Dialect *valueDialect =
-      maybeDefiningOp ? maybeDefiningOp->getDialect()
-                      : value.getParentRegion()->getParentOp()->getDialect();
-  Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
-  Value constant = folder.getOrCreateConstant(b, valueDialect, constAttr,
-                                              value.getType(), value.getLoc());
+  const ConstantValue &constantValue = constantState->getValue();
+  Value constant = folder.getOrCreateConstant(
+      b, constantValue.getConstantDialect(), constantValue.getConstantValue(),
+      value.getType(), value.getLoc());
   if (!constant)
     return failure();
 
@@ -106,6 +98,7 @@ struct TestIntRangeInference
     DataFlowSolver solver;
     solver.load<DeadCodeAnalysis>();
     solver.load<IntegerRangeAnalysis>();
+    solver.load<IntegerRangeToConstant>();
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
     rewrite(solver, op->getContext(), op->getRegions());


        


More information about the llvm-branch-commits mailing list