[Mlir-commits] [mlir] de0ebc5 - [mlir][dataflow] Consolidate AbstractSparseLattice::markPessimisticFixpoint() and AbstractDenseLattice::reset() into Abstract{Sparse, Dense}DataFlowAnalysis::setToEntryState().

Jeff Niu llvmlistbot at llvm.org
Mon Aug 29 09:01:06 PDT 2022


Author: Zhixun Tan
Date: 2022-08-29T09:00:55-07:00
New Revision: de0ebc5263f391519f3909c4f436afdf8adbc1ad

URL: https://github.com/llvm/llvm-project/commit/de0ebc5263f391519f3909c4f436afdf8adbc1ad
DIFF: https://github.com/llvm/llvm-project/commit/de0ebc5263f391519f3909c4f436afdf8adbc1ad.diff

LOG: [mlir][dataflow] Consolidate AbstractSparseLattice::markPessimisticFixpoint() and AbstractDenseLattice::reset() into Abstract{Sparse,Dense}DataFlowAnalysis::setToEntryState().

### Rationale

At a program point where the framework cannot reason about incoming dataflow (e.g., an argument of an entry block), it must still initialize the analysis state.

Currently, `AbstractSparseDataFlowAnalysis` initializes such states to the "pessimistic fixpoint", while `AbstractDenseDataFlowAnalysis` calls the state's `reset()` function.

However, entry states aren't necessarily the pessimistic fixpoint. For example, in a reaching-definitions analysis, the pessimistic fixpoint is `{all definitions}`, but the entry state is `{}`.
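To make the distinction concrete, here is a minimal C++ sketch of a toy reaching-definitions lattice. The names `DefSet`, `join`, `pessimisticFixpoint`, and `entryState` are hypothetical and not part of MLIR; definitions are modeled as integer IDs, and join is set union because reaching definitions is a "may" analysis.

```cpp
#include <cassert>
#include <set>

// Hypothetical toy lattice (not MLIR): the set of definition IDs that may
// reach a program point.
using DefSet = std::set<int>;

// Join is set union: a definition reaches if it reaches along any edge.
DefSet join(const DefSet &lhs, const DefSet &rhs) {
  DefSet result = lhs;
  result.insert(rhs.begin(), rhs.end());
  return result;
}

// The most conservative answer: every definition in the program may reach.
DefSet pessimisticFixpoint(const DefSet &allDefs) { return allDefs; }

// At the function entry no definition has executed yet, so nothing reaches.
DefSet entryState() { return {}; }
```

Initializing entry points with `pessimisticFixpoint` here would be sound but needlessly imprecise, which is the mismatch this patch removes.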

This awkwardness might be why the dense analysis API currently uses `reset()` instead of `markPessimisticFixpoint()`.

This patch consolidates entry point initialization into a single function `setToEntryState()`.

### API Location

Note that `setToEntryState()` is defined in the analysis rather than the lattice, so that different analyses can use the same lattice with different entry states.
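A minimal sketch of this design, assuming a toy integer lattice (`IntLattice`, joined by taking the maximum) and hypothetical analysis classes; none of these names are MLIR APIs. It only illustrates that the entry-state hook lives on the analysis, so two analyses sharing one lattice type can choose different entry states, and that `setAllToEntryStates` simply forwards to the per-lattice hook, as in the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical lattice: an optional int, empty while uninitialized.
struct IntLattice {
  int point;                 // Stand-in for the program point (e.g. a value ID).
  std::optional<int> value;  // std::nullopt means "uninitialized".

  // Join by maximum; returns true if the lattice changed.
  bool join(int rhs) {
    if (!value) {
      value = rhs;
      return true;
    }
    int newValue = std::max(*value, rhs);
    if (newValue == *value)
      return false;
    value = newValue;
    return true;
  }
};

// Hypothetical analysis base: the entry state is chosen by the analysis.
struct Analysis {
  virtual ~Analysis() = default;

  // Initialize a lattice at a point with no known incoming dataflow.
  virtual void setToEntryState(IntLattice &lattice) = 0;

  // Same shape as setAllToEntryStates in the patch: forward to the hook.
  void setAllToEntryStates(const std::vector<IntLattice *> &lattices) {
    for (IntLattice *lattice : lattices)
      setToEntryState(*lattice);
  }

  int updates = 0;  // Counts the updates that would be propagated.

protected:
  void propagateIfChanged(bool changed) {
    if (changed)
      ++updates;
  }
};

// Same lattice type, entry state 0.
struct ZeroEntryAnalysis : Analysis {
  void setToEntryState(IntLattice &lattice) override {
    propagateIfChanged(lattice.join(0));
  }
};

// Same lattice type, entry state derived from the point itself.
struct SelfEntryAnalysis : Analysis {
  void setToEntryState(IntLattice &lattice) override {
    propagateIfChanged(lattice.join(lattice.point));
  }
};
```

Because `join` is idempotent, calling the hook again on an already-initialized lattice reports no change, which is what lets `propagateIfChanged` avoid redundant work.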

### Removal of the concept of optimistic/known value

The concept of optimistic/known value is too specific to SCCP.

Furthermore, the known value is not really used: in the current SCCP implementation, the known value (pessimistic fixpoint) is always `Attribute{}` (non-constant), so there is no point in storing a `knownValue` in each state.
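The following sketch loosely mirrors the post-patch `Lattice<ValueT>` and `ConstantValue` shapes, but uses hypothetical standalone types (an `int` stands in for `Attribute`), so it is an illustration rather than MLIR code. It shows why SCCP does not need a separate `knownValue`: once the entry state joins in the unknown constant, every later join stays unknown.

```cpp
#include <cassert>
#include <optional>

// Hypothetical SCCP-style value: a constant, or empty when "unknown".
struct ConstantValue {
  std::optional<int> constant;

  static ConstantValue getUnknownConstant() { return {}; }

  // Equal constants stay as-is; anything else joins to unknown.
  static ConstantValue join(const ConstantValue &lhs,
                            const ConstantValue &rhs) {
    return lhs == rhs ? lhs : getUnknownConstant();
  }

  bool operator==(const ConstantValue &other) const {
    return constant == other.constant;
  }
};

// Hypothetical single-value lattice: no knownValue, only an optional that
// is empty while the lattice is uninitialized.
template <typename ValueT>
struct Lattice {
  std::optional<ValueT> value;

  bool isUninitialized() const { return !value.has_value(); }

  // Returns true if the lattice changed.
  bool join(const ValueT &rhs) {
    if (isUninitialized()) {
      value = rhs;
      return true;
    }
    ValueT newValue = ValueT::join(*value, rhs);
    if (newValue == *value)
      return false;
    value = newValue;
    return true;
  }
};

// Entry state for the constant analysis: join with the unknown constant.
bool setToEntryState(Lattice<ConstantValue> &lattice) {
  return lattice.join(ConstantValue::getUnknownConstant());
}
```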

If we do need to re-introduce optimistic/known value, we should put it in the SCCP analysis, not the sparse analysis API.

### Terminology

Please let me know if "entry state" is good terminology.

I chose "entry" from Wikipedia (https://en.wikipedia.org/wiki/Data-flow_analysis#Basic_principles).

Another term I can think of is "boundary" (https://suif.stanford.edu/~courses/cs243/lectures/L3-DFA2-revised.pdf) which might be better since it also makes sense for backward analysis.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D132086

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/include/mlir/Analysis/DataFlowFramework.h
    mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
    mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
    mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
    mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
    mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index 145170f74e546..0935d8b2da5c8 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -46,8 +46,8 @@ class ConstantValue {
   /// Print the constant value.
   void print(raw_ostream &os) const;
 
-  /// The pessimistic value state of the constant value is unknown.
-  static ConstantValue getPessimisticValueState(Value value) { return {}; }
+  /// The state where the constant value is unknown.
+  static ConstantValue getUnknownConstant() { return {}; }
 
   /// The union with another constant value is null if they are different, and
   /// the same if they are the same.
@@ -79,6 +79,8 @@ class SparseConstantPropagation
   void visitOperation(Operation *op,
                       ArrayRef<const Lattice<ConstantValue> *> operands,
                       ArrayRef<Lattice<ConstantValue> *> results) override;
+
+  void setToEntryState(Lattice<ConstantValue> *lattice) override;
 };
 
 } // end namespace dataflow

diff  --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 1e66a4ac59c1d..32733a127ad35 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -38,10 +38,6 @@ class AbstractDenseLattice : public AnalysisState {
 
   /// Join the lattice across control-flow or callgraph edges.
   virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0;
-
-  /// Reset the dense lattice to a pessimistic value. This occurs when the
-  /// analysis cannot reason about the data-flow.
-  virtual ChangeResult reset() = 0;
 };
 
 //===----------------------------------------------------------------------===//
@@ -88,11 +84,9 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
   const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
                                             ProgramPoint point);
 
-  /// Mark the dense lattice as having reached its pessimistic fixpoint and
-  /// propagate an update if it changed.
-  void reset(AbstractDenseLattice *lattice) {
-    propagateIfChanged(lattice, lattice->reset());
-  }
+  /// Set the dense lattice at control flow entry point and propagate an update
+  /// if it changed.
+  virtual void setToEntryState(AbstractDenseLattice *lattice) = 0;
 
   /// Join a lattice with another and propagate an update if it changed.
   void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) {
@@ -147,6 +141,13 @@ class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis {
     return getOrCreate<LatticeT>(point);
   }
 
+  /// Set the dense lattice at control flow entry point and propagate an update
+  /// if it changed.
+  virtual void setToEntryState(LatticeT *lattice) = 0;
+  void setToEntryState(AbstractDenseLattice *lattice) override {
+    setToEntryState(static_cast<LatticeT *>(lattice));
+  }
+
 private:
   /// Type-erased wrappers that convert the abstract dense lattice to a derived
   /// lattice and invoke the virtual hooks operating on the derived lattice.

diff  --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 3cd007ab478ba..4cacc57477aa7 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -27,7 +27,7 @@ class IntegerValueRange {
   /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
   /// range that is used to mark the value as unable to be analyzed further,
   /// where `t` is the type of `value`.
-  static IntegerValueRange getPessimisticValueState(Value value);
+  static IntegerValueRange getMaxRange(Value value);
 
   /// Create an integer value range lattice value.
   IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
@@ -74,6 +74,12 @@ class IntegerRangeAnalysis
 public:
   using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
 
+  /// At an entry point, we cannot reason about integer value ranges.
+  void setToEntryState(IntegerValueRangeLattice *lattice) override {
+    propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange(
+                                    lattice->getPoint())));
+  }
+
   /// Visit an operation. Invoke the transfer function on each operation that
   /// implements `InferIntRangeInterface`.
   void visitOperation(Operation *op,

diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 480af8b4320f5..dc4dd0978fdb7 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -34,15 +34,13 @@ class AbstractSparseLattice : public AnalysisState {
   /// Lattices can only be created for values.
   AbstractSparseLattice(Value value) : AnalysisState(value) {}
 
+  /// Return the program point this lattice is located at.
+  Value getPoint() const { return AnalysisState::getPoint().get<Value>(); }
+
   /// Join the information contained in 'rhs' into this lattice. Returns
   /// if the value of the lattice changed.
   virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
 
-  /// Mark the lattice element as having reached a pessimistic fixpoint. This
-  /// means that the lattice may potentially have conflicting value states, and
-  /// only the most conservative value should be relied on.
-  virtual ChangeResult markPessimisticFixpoint() = 0;
-
   /// When the lattice gets updated, propagate an update to users of the value
   /// using its use-def chain to subscribed analyses.
   void onUpdate(DataFlowSolver *solver) const override;
@@ -76,23 +74,23 @@ class AbstractSparseLattice : public AnalysisState {
 template <typename ValueT>
 class Lattice : public AbstractSparseLattice {
 public:
-  /// Construct a lattice with a known value.
-  explicit Lattice(Value value)
-      : AbstractSparseLattice(value),
-        knownValue(ValueT::getPessimisticValueState(value)) {}
+  using AbstractSparseLattice::AbstractSparseLattice;
+
+  /// Return the program point this lattice is located at.
+  Value getPoint() const { return point.get<Value>(); }
 
   /// Return the value held by this lattice. This requires that the value is
   /// initialized.
   ValueT &getValue() {
     assert(!isUninitialized() && "expected known lattice element");
-    return *optimisticValue;
+    return *value;
   }
   const ValueT &getValue() const {
     return const_cast<Lattice<ValueT> *>(this)->getValue();
   }
 
   /// Returns true if the value of this lattice hasn't yet been initialized.
-  bool isUninitialized() const override { return !optimisticValue.has_value(); }
+  bool isUninitialized() const override { return !value.has_value(); }
 
   /// Join the information contained in the 'rhs' lattice into this
   /// lattice. Returns if the state of the current lattice changed.
@@ -113,56 +111,37 @@ class Lattice : public AbstractSparseLattice {
   ChangeResult join(const ValueT &rhs) {
     // If the current lattice is uninitialized, copy the rhs value.
     if (isUninitialized()) {
-      optimisticValue = rhs;
+      value = rhs;
       return ChangeResult::Change;
     }
 
     // Otherwise, join rhs with the current optimistic value.
-    ValueT newValue = ValueT::join(*optimisticValue, rhs);
-    assert(ValueT::join(newValue, *optimisticValue) == newValue &&
+    ValueT newValue = ValueT::join(*value, rhs);
+    assert(ValueT::join(newValue, *value) == newValue &&
            "expected `join` to be monotonic");
     assert(ValueT::join(newValue, rhs) == newValue &&
            "expected `join` to be monotonic");
 
     // Update the current optimistic value if something changed.
-    if (newValue == optimisticValue)
+    if (newValue == value)
       return ChangeResult::NoChange;
 
-    optimisticValue = newValue;
-    return ChangeResult::Change;
-  }
-
-  /// Mark the lattice element as having reached a pessimistic fixpoint. This
-  /// means that the lattice may potentially have conflicting value states,
-  /// and only the conservatively known value state should be relied on.
-  ChangeResult markPessimisticFixpoint() override {
-    if (optimisticValue == knownValue)
-      return ChangeResult::NoChange;
-
-    // For this fixed point, we take whatever we knew to be true and set that
-    // to our optimistic value.
-    optimisticValue = knownValue;
+    value = newValue;
     return ChangeResult::Change;
   }
 
   /// Print the lattice element.
   void print(raw_ostream &os) const override {
-    os << "[";
-    knownValue.print(os);
-    os << ", ";
-    if (optimisticValue)
-      optimisticValue->print(os);
+    if (value)
+      value->print(os);
     else
       os << "<NULL>";
-    os << "]";
   }
 
 private:
-  /// The value that is conservatively known to be true.
-  ValueT knownValue;
   /// The currently computed value that is optimistically assumed to be true,
   /// or None if the lattice element is uninitialized.
-  Optional<ValueT> optimisticValue;
+  Optional<ValueT> value;
 };
 
 //===----------------------------------------------------------------------===//
@@ -213,9 +192,9 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
   const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
                                                     Value value);
 
-  /// Mark the given lattice elements as having reached their pessimistic
-  /// fixpoints and propagate an update if any changed.
-  void markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice *> lattices);
+  /// Set the given lattice element(s) at control flow entry point(s).
+  virtual void setToEntryState(AbstractSparseLattice *lattice) = 0;
+  void setAllToEntryStates(ArrayRef<AbstractSparseLattice *> lattices);
 
   /// Join the lattice element and propagate an update if it changed.
   void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
@@ -278,8 +257,8 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
                                             const RegionSuccessor &successor,
                                             ArrayRef<StateT *> argLattices,
                                             unsigned firstIndex) {
-    markAllPessimisticFixpoint(argLattices.take_front(firstIndex));
-    markAllPessimisticFixpoint(argLattices.drop_front(
+    setAllToEntryStates(argLattices.take_front(firstIndex));
+    setAllToEntryStates(argLattices.drop_front(
         firstIndex + successor.getSuccessorInputs().size()));
   }
 
@@ -296,10 +275,10 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
         AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value));
   }
 
-  /// Mark the lattice elements of a range of values as having reached their
-  /// pessimistic fixpoint.
-  void markAllPessimisticFixpoint(ArrayRef<StateT *> lattices) {
-    AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
+  /// Set the given lattice element(s) at control flow entry point(s).
+  virtual void setToEntryState(StateT *lattice) = 0;
+  void setAllToEntryStates(ArrayRef<StateT *> lattices) {
+    AbstractSparseDataFlowAnalysis::setAllToEntryStates(
         {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
          lattices.size()});
   }
@@ -327,6 +306,9 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
          argLattices.size()},
         firstIndex);
   }
+  void setToEntryState(AbstractSparseLattice *lattice) override {
+    return setToEntryState(reinterpret_cast<StateT *>(lattice));
+  }
 };
 
 } // end namespace dataflow

diff  --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 37603e5eacef8..e6805ed78cfc1 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -288,6 +288,9 @@ class AnalysisState {
   /// Create the analysis state at the given program point.
   AnalysisState(ProgramPoint point) : point(point) {}
 
+  /// Returns the program point this state is located at.
+  ProgramPoint getPoint() const { return point; }
+
   /// Returns true if the analysis state is uninitialized.
   virtual bool isUninitialized() const = 0;
 

diff  --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 19d0b7d0f5f70..839c55d054d3c 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -39,7 +39,7 @@ void SparseConstantPropagation::visitOperation(
   // folds as the desire here is for simulated execution, and not general
   // folding.
   if (op->getNumRegions()) {
-    markAllPessimisticFixpoint(results);
+    setAllToEntryStates(results);
     return;
   }
 
@@ -59,7 +59,7 @@ void SparseConstantPropagation::visitOperation(
   SmallVector<OpFoldResult, 8> foldResults;
   foldResults.reserve(op->getNumResults());
   if (failed(op->fold(constantOperands, foldResults))) {
-    markAllPessimisticFixpoint(results);
+    setAllToEntryStates(results);
     return;
   }
 
@@ -69,7 +69,7 @@ void SparseConstantPropagation::visitOperation(
   if (foldResults.empty()) {
     op->setOperands(originalOperands);
     op->setAttrs(originalAttrs);
-    markAllPessimisticFixpoint(results);
+    setAllToEntryStates(results);
     return;
   }
 
@@ -92,3 +92,9 @@ void SparseConstantPropagation::visitOperation(
     }
   }
 }
+
+void SparseConstantPropagation::setToEntryState(
+    Lattice<ConstantValue> *lattice) {
+  propagateIfChanged(lattice,
+                     lattice->join(ConstantValue::getUnknownConstant()));
+}

diff  --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 659ae3c38a942..55e1cb1ff95c3 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -62,7 +62,7 @@ void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
     // If not all return sites are known, then conservatively assume we can't
     // reason about the data-flow.
     if (!predecessors->allPredecessorsKnown())
-      return reset(after);
+      return setToEntryState(after);
     for (Operation *predecessor : predecessors->getKnownPredecessors())
       join(after, *getLatticeFor(op, predecessor));
     return;
@@ -100,7 +100,7 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
       // If not all callsites are known, conservatively mark all lattices as
       // having reached their pessimistic fixpoints.
       if (!callsites->allPredecessorsKnown())
-        return reset(after);
+        return setToEntryState(after);
       for (Operation *callsite : callsites->getKnownPredecessors()) {
         // Get the dense lattice before the callsite.
         if (Operation *prev = callsite->getPrevNode())
@@ -116,7 +116,7 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
       return visitRegionBranchOperation(block, branch, after);
 
     // Otherwise, we can't reason about the data-flow.
-    return reset(after);
+    return setToEntryState(after);
   }
 
   // Join the state with the state after the block's predecessors.

diff  --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 0147da813dfb7..316097d052938 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -23,7 +23,7 @@
 using namespace mlir;
 using namespace mlir::dataflow;
 
-IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
+IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
   unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
   APInt umin = APInt::getMinValue(width);
   APInt umax = APInt::getMaxValue(width);
@@ -41,7 +41,8 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
   auto value = point.get<Value>();
   auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
   if (!constant)
-    return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
+    return solver->propagateIfChanged(
+        cv, cv->join(ConstantValue::getUnknownConstant()));
 
   Dialect *dialect;
   if (auto *parent = value.getDefiningOp())
@@ -60,11 +61,13 @@ void IntegerRangeAnalysis::visitOperation(
   // integer results
   bool hasIntegerResult = false;
   for (auto it : llvm::zip(results, op->getResults())) {
-    if (std::get<1>(it).getType().isIntOrIndex()) {
+    Value value = std::get<1>(it);
+    if (value.getType().isIntOrIndex()) {
       hasIntegerResult = true;
     } else {
-      propagateIfChanged(std::get<0>(it),
-                         std::get<0>(it)->markPessimisticFixpoint());
+      IntegerValueRangeLattice *lattice = std::get<0>(it);
+      propagateIfChanged(lattice,
+                         lattice->join(IntegerValueRange::getMaxRange(value)));
     }
   }
   if (!hasIntegerResult)
@@ -72,7 +75,7 @@ void IntegerRangeAnalysis::visitOperation(
 
   auto inferrable = dyn_cast<InferIntRangeInterface>(op);
   if (!inferrable)
-    return markAllPessimisticFixpoint(results);
+    return setAllToEntryStates(results);
 
   LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
   SmallVector<ConstantIntRanges> argRanges(
@@ -104,7 +107,7 @@ void IntegerRangeAnalysis::visitOperation(
     if (isYieldedResult && oldRange.has_value() &&
         !(lattice->getValue() == *oldRange)) {
       LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
-      changed |= lattice->markPessimisticFixpoint();
+      changed |= lattice->join(IntegerValueRange::getMaxRange(v));
     }
     propagateIfChanged(lattice, changed);
   };
@@ -146,7 +149,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
       });
       if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
         LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
-        changed |= lattice->markPessimisticFixpoint();
+        changed |= lattice->join(IntegerValueRange::getMaxRange(v));
       }
       propagateIfChanged(lattice, changed);
     };

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index f85d8dfb279c6..b565b310bf9f7 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -41,7 +41,7 @@ LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
     if (region.empty())
       continue;
     for (Value argument : region.front().getArguments())
-      markAllPessimisticFixpoint(getLatticeElement(argument));
+      setAllToEntryStates(getLatticeElement(argument));
   }
 
   return initializeRecursively(top);
@@ -104,7 +104,7 @@ void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
     // If not all return sites are known, then conservatively assume we can't
     // reason about the data-flow.
     if (!predecessors->allPredecessorsKnown())
-      return markAllPessimisticFixpoint(resultLattices);
+      return setAllToEntryStates(resultLattices);
     for (Operation *predecessor : predecessors->getKnownPredecessors())
       for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
         join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
@@ -154,7 +154,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
       // If not all callsites are known, conservatively mark all lattices as
       // having reached their pessimistic fixpoints.
       if (!callsites->allPredecessorsKnown())
-        return markAllPessimisticFixpoint(argLattices);
+        return setAllToEntryStates(argLattices);
       for (Operation *callsite : callsites->getKnownPredecessors()) {
         auto call = cast<CallOpInterface>(callsite);
         for (auto it : llvm::zip(call.getArgOperands(), argLattices))
@@ -197,13 +197,13 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
         if (Value operand = operands[it.index()]) {
           join(it.value(), *getLatticeElementFor(block, operand));
         } else {
-          // Conservatively mark internally produced arguments as having reached
-          // their pessimistic fixpoint.
-          markAllPessimisticFixpoint(it.value());
+          // Conservatively consider internally produced arguments as entry
+          // points.
+          setAllToEntryStates(it.value());
         }
       }
     } else {
-      return markAllPessimisticFixpoint(argLattices);
+      return setAllToEntryStates(argLattices);
     }
   }
 }
@@ -231,7 +231,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
 
     if (!operands) {
       // We can't reason about the data-flow.
-      return markAllPessimisticFixpoint(lattices);
+      return setAllToEntryStates(lattices);
     }
 
     ValueRange inputs = predecessors->getSuccessorInputs(op);
@@ -273,10 +273,10 @@ AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
   return state;
 }
 
-void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
+void AbstractSparseDataFlowAnalysis::setAllToEntryStates(
     ArrayRef<AbstractSparseLattice *> lattices) {
   for (AbstractSparseLattice *lattice : lattices)
-    propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
+    setToEntryState(lattice);
 }
 
 void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,

diff  --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index 27e994cce3b64..90973af9c2cf5 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -84,19 +84,18 @@ struct ConstantAnalysis : public DataFlowAnalysis {
           constant, constant->join(ConstantValue(value, op->getDialect())));
       return success();
     }
-    markAllPessimisticFixpoint(op->getResults());
+    setAllToUnknownConstants(op->getResults());
     for (Region &region : op->getRegions())
-      markAllPessimisticFixpoint(region.getArguments());
+      setAllToUnknownConstants(region.getArguments());
     return success();
   }
 
-  /// Mark the constant values of all given values as having reached a
-  /// pessimistic fixpoint.
-  void markAllPessimisticFixpoint(ValueRange values) {
+  /// Set all given values as not constants.
+  void setAllToUnknownConstants(ValueRange values) {
     for (Value value : values) {
-      auto *constantValue = getOrCreate<Lattice<ConstantValue>>(value);
-      propagateIfChanged(constantValue,
-                         constantValue->markPessimisticFixpoint());
+      auto *constant = getOrCreate<Lattice<ConstantValue>>(value);
+      propagateIfChanged(constant,
+                         constant->join(ConstantValue::getUnknownConstant()));
     }
   }
 };

diff  --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
index 0fc80a6f70a37..c9eec64e75cdc 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
@@ -20,14 +20,8 @@ namespace {
 /// This lattice represents a single underlying value for an SSA value.
 class UnderlyingValue {
 public:
-  /// The pessimistic underlying value of a value is itself.
-  static UnderlyingValue getPessimisticValueState(Value value) {
-    return {value};
-  }
-
   /// Create an underlying value state with a known underlying value.
-  UnderlyingValue(Value underlyingValue = {})
-      : underlyingValue(underlyingValue) {}
+  UnderlyingValue(Value underlyingValue) : underlyingValue(underlyingValue) {}
 
   /// Returns the underlying value.
   Value getUnderlyingValue() const { return underlyingValue; }
@@ -36,7 +30,7 @@ class UnderlyingValue {
   /// go to the pessimistic value.
   static UnderlyingValue join(const UnderlyingValue &lhs,
                               const UnderlyingValue &rhs) {
-    return lhs.underlyingValue == rhs.underlyingValue ? lhs : UnderlyingValue();
+    return lhs.underlyingValue == rhs.underlyingValue ? lhs : Value();
   }
 
   /// Compare underlying values.
@@ -61,9 +55,8 @@ class LastModification : public AbstractDenseLattice {
   /// The lattice is always initialized.
   bool isUninitialized() const override { return false; }
 
-  /// Mark the lattice as having reached its pessimistic fixpoint. That is, the
-  /// last modifications of all memory resources are unknown.
-  ChangeResult reset() override {
+  /// Clear all modifications.
+  ChangeResult reset() {
     if (lastMods.empty())
       return ChangeResult::NoChange;
     lastMods.clear();
@@ -131,6 +124,12 @@ class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
   /// resource, then its reaching definition is set to the written value.
   void visitOperation(Operation *op, const LastModification &before,
                       LastModification *after) override;
+
+  /// At an entry point, the last modifications of all memory resources are
+  /// unknown.
+  void setToEntryState(LastModification *lattice) override {
+    propagateIfChanged(lattice, lattice->reset());
+  }
 };
 
 /// Define the lattice class explicitly to provide a type ID.
@@ -152,7 +151,13 @@ class UnderlyingValueAnalysis
   void visitOperation(Operation *op,
                       ArrayRef<const UnderlyingValueLattice *> operands,
                       ArrayRef<UnderlyingValueLattice *> results) override {
-    markAllPessimisticFixpoint(results);
+    setAllToEntryStates(results);
+  }
+
+  /// At an entry point, the underlying value of a value is itself.
+  void setToEntryState(UnderlyingValueLattice *lattice) override {
+    propagateIfChanged(lattice,
+                       lattice->join(UnderlyingValue{lattice->getPoint()}));
   }
 };
 } // end anonymous namespace
@@ -181,7 +186,7 @@ void LastModifiedAnalysis::visitOperation(Operation *op,
   // If we can't reason about the memory effects, then conservatively assume we
   // can't deduce anything about the last modifications.
   if (!memory)
-    return reset(after);
+    return setToEntryState(after);
 
   SmallVector<MemoryEffects::EffectInstance> effects;
   memory.getEffects(effects);
@@ -193,7 +198,7 @@ void LastModifiedAnalysis::visitOperation(Operation *op,
     // If we see an effect on anything other than a value, assume we can't
     // deduce anything about the last modifications.
     if (!value)
-      return reset(after);
+      return setToEntryState(after);
 
     value = getMostUnderlyingValue(value, [&](Value value) {
       return getOrCreateFor<UnderlyingValueLattice>(op, value);


        

