[Mlir-commits] [mlir] de0ebc5 - [mlir][dataflow] Consolidate AbstractSparseLattice::markPessimisticFixpoint() and AbstractDenseLattice::reset() into Abstract{Sparse, Dense}DataFlowAnalysis::setToEntryState().
Jeff Niu
llvmlistbot at llvm.org
Mon Aug 29 09:01:06 PDT 2022
Author: Zhixun Tan
Date: 2022-08-29T09:00:55-07:00
New Revision: de0ebc5263f391519f3909c4f436afdf8adbc1ad
URL: https://github.com/llvm/llvm-project/commit/de0ebc5263f391519f3909c4f436afdf8adbc1ad
DIFF: https://github.com/llvm/llvm-project/commit/de0ebc5263f391519f3909c4f436afdf8adbc1ad.diff
LOG: [mlir][dataflow] Consolidate AbstractSparseLattice::markPessimisticFixpoint() and AbstractDenseLattice::reset() into Abstract{Sparse,Dense}DataFlowAnalysis::setToEntryState().
### Rationale
For a program point where we cannot reason about incoming dataflow (e.g. an argument of an entry block), the framework needs to initialize the state.
Currently, `AbstractSparseDataFlowAnalysis` initializes such state to the "pessimistic fixpoint", and `AbstractDenseDataFlowAnalysis` calls the state's `reset()` function.
However, entry states aren't necessarily the pessimistic fixpoint. For example, in reaching-definitions analysis the pessimistic fixpoint is `{all definitions}`, but the entry state is `{}`.
This awkwardness might be why the dense analysis API currently uses `reset()` instead of `markPessimisticFixpoint()`.
This patch consolidates entry point initialization into a single function `setToEntryState()`.
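The reaching-definitions example can be made concrete with a small framework-independent sketch. It does not use the MLIR API. It assumes a hypothetical straight-line program where `d1: x = 1` runs first, `d2: y = 2` follows, and `d3: x = 3` appears later. Definitions are plain string labels, join is set union, and `entryState()` / `pessimisticFixpoint()` are illustrative names, not framework functions. Seeding the first instruction with the pessimistic fixpoint wrongly reports that `d2` reaches the point after `d1`. Seeding it with the empty entry state does not.

```cpp
#include <cassert>
#include <set>
#include <string>

// Hypothetical lattice element for reaching definitions: a set of labels.
using DefSet = std::set<std::string>;

// Join for a "may reach" analysis is set union.
DefSet join(const DefSet &lhs, const DefSet &rhs) {
  DefSet result = lhs;
  result.insert(rhs.begin(), rhs.end());
  return result;
}

// Transfer function for one definition: remove the definitions it kills
// (other definitions of the same variable), then add the definition itself.
DefSet transfer(DefSet in, const std::string &gen, const DefSet &kill) {
  for (const std::string &def : kill)
    in.erase(def);
  in.insert(gen);
  return in;
}

// Entry state: at program entry no definition has executed yet.
DefSet entryState() { return {}; }

// Pessimistic fixpoint: every definition may reach. This is conservative,
// but it is not what holds at the entry point.
DefSet pessimisticFixpoint(const DefSet &allDefs) { return allDefs; }
```

For example, `transfer(entryState(), "d1", {"d3"})` yields `{d1}`. `transfer(pessimisticFixpoint({"d1", "d2", "d3"}), "d1", {"d3"})` yields `{d1, d2}`, even though `d2` has not executed yet.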
### API Location
Note that `setToEntryState()` is defined in the analysis rather than in the lattice, so that different analyses can use the same lattice with different entry states.
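The sketch below illustrates why the hook belongs on the analysis. It is a standalone approximation, not the MLIR classes. `Interval`, `IntervalLattice`, `SignedRangeAnalysis`, and `IndexRangeAnalysis` are hypothetical. The lattice only stores an optional value and knows how to join, in the same shape as `Lattice<ValueT>` after this patch. Each analysis picks its own entry state through a virtual `setToEntryState()`. A real analysis would pass the returned `ChangeResult` to `propagateIfChanged`.

```cpp
#include <algorithm>
#include <cassert>
#include <climits>
#include <optional>

// Hypothetical stand-in for the framework's ChangeResult.
enum class ChangeResult { NoChange, Change };

// A closed integer interval [lo, hi]; join is the smallest enclosing interval.
struct Interval {
  int lo, hi;
  bool operator==(const Interval &o) const { return lo == o.lo && hi == o.hi; }
  static Interval join(const Interval &a, const Interval &b) {
    return {std::min(a.lo, b.lo), std::max(a.hi, b.hi)};
  }
};

// The lattice holds no notion of an entry or pessimistic value.
class IntervalLattice {
public:
  bool isUninitialized() const { return !value.has_value(); }
  const Interval &getValue() const { return *value; }
  ChangeResult join(const Interval &rhs) {
    if (isUninitialized()) {
      value = rhs;
      return ChangeResult::Change;
    }
    Interval newValue = Interval::join(*value, rhs);
    if (newValue == *value)
      return ChangeResult::NoChange;
    value = newValue;
    return ChangeResult::Change;
  }

private:
  std::optional<Interval> value;
};

// The entry-state hook lives on the analysis, not on the lattice.
class IntervalAnalysisBase {
public:
  virtual ~IntervalAnalysisBase() = default;
  virtual void setToEntryState(IntervalLattice *lattice) = 0;
};

// Hypothetical analysis: nothing is known about an integer at an entry point.
class SignedRangeAnalysis : public IntervalAnalysisBase {
public:
  void setToEntryState(IntervalLattice *lattice) override {
    lattice->join({INT_MIN, INT_MAX});
  }
};

// Hypothetical analysis that assumes entry values are valid (non-negative) indices.
class IndexRangeAnalysis : public IntervalAnalysisBase {
public:
  void setToEntryState(IntervalLattice *lattice) override {
    lattice->join({0, INT_MAX});
  }
};
```

Both analyses share `IntervalLattice` unchanged. Only the overridden hook differs.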
### Removal of the concept of optimistic/known value
The concept of optimistic/known value is too specific to SCCP.
Furthermore, the known value is not really used: in the current SCCP implementation, the known value (pessimistic fixpoint) is always `Attribute{}` (non-constant), so there is no point in storing a `knownValue` in each state.
If we do need to re-introduce optimistic/known value, we should put it in the SCCP analysis, not the sparse analysis API.
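As a rough check of the argument above, the following standalone sketch models a constant-propagation lattice. It is not the MLIR implementation. `ConstantValue` here wraps a `std::optional<int>`, and an empty optional stands in for the null `Attribute{}`. `ConstantLattice` and the free function `setToEntryState` are hypothetical simplifications of `Lattice<ConstantValue>` and the analysis hook. Because differing constants join to "unknown", joining with the unknown constant drives any state to "unknown". That is the state `markPessimisticFixpoint()` used to restore from the stored `knownValue`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for ConstantValue: an empty optional plays the role
// of the null Attribute{}, i.e. "unknown / not a constant".
struct ConstantValue {
  std::optional<int> constant;
  static ConstantValue getUnknownConstant() { return {}; }
  bool isUnknown() const { return !constant.has_value(); }
  bool operator==(const ConstantValue &o) const { return constant == o.constant; }
  // Equal values are kept; differing values join to unknown.
  static ConstantValue join(const ConstantValue &lhs, const ConstantValue &rhs) {
    return lhs == rhs ? lhs : getUnknownConstant();
  }
};

// Simplified Lattice<ConstantValue>: an empty `value` means "uninitialized".
// There is no separate knownValue member.
struct ConstantLattice {
  std::optional<ConstantValue> value;

  // Returns true if the lattice changed.
  bool join(const ConstantValue &rhs) {
    if (!value) {
      value = rhs;
      return true;
    }
    ConstantValue newValue = ConstantValue::join(*value, rhs);
    if (newValue == *value)
      return false;
    value = newValue;
    return true;
  }
};

// Entry state in the style of this patch: join with the unknown constant.
bool setToEntryState(ConstantLattice *lattice) {
  return lattice->join(ConstantValue::getUnknownConstant());
}
```

Whether the lattice starts uninitialized, holds a constant, or is already unknown, `setToEntryState` leaves it unknown. It only reports a change in the first two cases.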
### Terminology
Please let me know if "entry state" is good terminology.
I chose "entry" from Wikipedia (https://en.wikipedia.org/wiki/Data-flow_analysis#Basic_principles).
Another term I can think of is "boundary" (https://suif.stanford.edu/~courses/cs243/lectures/L3-DFA2-revised.pdf) which might be better since it also makes sense for backward analysis.
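To illustrate why "boundary" also fits backward analyses, here is a hypothetical liveness sketch, independent of the MLIR framework. In a backward analysis, the state that cannot be derived from neighbors sits at the block's exit, not its entry. The sketch assumes a straight-line block where each `Inst` defines at most one variable. The boundary state at the exit is taken to be "only the returned variable is live". The names `Inst`, `boundaryState`, and `liveIn` are illustrative only.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical instruction: defines at most one variable, uses some others.
struct Inst {
  std::string def;                // empty if nothing is defined
  std::vector<std::string> uses;
};

using LiveSet = std::set<std::string>;

// Boundary state of a backward analysis, applied at the block's exit where
// nothing flows in from successors. Here only the returned variable is live.
LiveSet boundaryState(const std::string &returned) { return {returned}; }

// Walk the block from the last instruction upward:
// live_in = uses union (live_out minus def).
LiveSet liveIn(const std::vector<Inst> &block, LiveSet live) {
  for (auto it = block.rbegin(); it != block.rend(); ++it) {
    if (!it->def.empty())
      live.erase(it->def);
    live.insert(it->uses.begin(), it->uses.end());
  }
  return live;
}
```

Consider the block `b = a + 1; c = b * 2; return c`, modeled as `{{"b", {"a"}}, {"c", {"b"}}}`. Starting from `boundaryState("c")` at the exit, only `a` is live on entry.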
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D132086
Added:
Modified:
mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/include/mlir/Analysis/DataFlowFramework.h
mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index 145170f74e546..0935d8b2da5c8 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -46,8 +46,8 @@ class ConstantValue {
/// Print the constant value.
void print(raw_ostream &os) const;
- /// The pessimistic value state of the constant value is unknown.
- static ConstantValue getPessimisticValueState(Value value) { return {}; }
+ /// The state where the constant value is unknown.
+ static ConstantValue getUnknownConstant() { return {}; }
/// The union with another constant value is null if they are different, and
/// the same if they are the same.
@@ -79,6 +79,8 @@ class SparseConstantPropagation
void visitOperation(Operation *op,
ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) override;
+
+ void setToEntryState(Lattice<ConstantValue> *lattice) override;
};
} // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 1e66a4ac59c1d..32733a127ad35 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -38,10 +38,6 @@ class AbstractDenseLattice : public AnalysisState {
/// Join the lattice across control-flow or callgraph edges.
virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0;
-
- /// Reset the dense lattice to a pessimistic value. This occurs when the
- /// analysis cannot reason about the data-flow.
- virtual ChangeResult reset() = 0;
};
//===----------------------------------------------------------------------===//
@@ -88,11 +84,9 @@ class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis {
const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent,
ProgramPoint point);
- /// Mark the dense lattice as having reached its pessimistic fixpoint and
- /// propagate an update if it changed.
- void reset(AbstractDenseLattice *lattice) {
- propagateIfChanged(lattice, lattice->reset());
- }
+ /// Set the dense lattice at control flow entry point and propagate an update
+ /// if it changed.
+ virtual void setToEntryState(AbstractDenseLattice *lattice) = 0;
/// Join a lattice with another and propagate an update if it changed.
void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) {
@@ -147,6 +141,13 @@ class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis {
return getOrCreate<LatticeT>(point);
}
+ /// Set the dense lattice at control flow entry point and propagate an update
+ /// if it changed.
+ virtual void setToEntryState(LatticeT *lattice) = 0;
+ void setToEntryState(AbstractDenseLattice *lattice) override {
+ setToEntryState(static_cast<LatticeT *>(lattice));
+ }
+
private:
/// Type-erased wrappers that convert the abstract dense lattice to a derived
/// lattice and invoke the virtual hooks operating on the derived lattice.
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
index 3cd007ab478ba..4cacc57477aa7 100644
--- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -27,7 +27,7 @@ class IntegerValueRange {
/// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
/// range that is used to mark the value as unable to be analyzed further,
/// where `t` is the type of `value`.
- static IntegerValueRange getPessimisticValueState(Value value);
+ static IntegerValueRange getMaxRange(Value value);
/// Create an integer value range lattice value.
IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
@@ -74,6 +74,12 @@ class IntegerRangeAnalysis
public:
using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
+ /// At an entry point, we cannot reason about integer value ranges.
+ void setToEntryState(IntegerValueRangeLattice *lattice) override {
+ propagateIfChanged(lattice, lattice->join(IntegerValueRange::getMaxRange(
+ lattice->getPoint())));
+ }
+
/// Visit an operation. Invoke the transfer function on each operation that
/// implements `InferIntRangeInterface`.
void visitOperation(Operation *op,
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 480af8b4320f5..dc4dd0978fdb7 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -34,15 +34,13 @@ class AbstractSparseLattice : public AnalysisState {
/// Lattices can only be created for values.
AbstractSparseLattice(Value value) : AnalysisState(value) {}
+ /// Return the program point this lattice is located at.
+ Value getPoint() const { return AnalysisState::getPoint().get<Value>(); }
+
/// Join the information contained in 'rhs' into this lattice. Returns
/// if the value of the lattice changed.
virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
- /// Mark the lattice element as having reached a pessimistic fixpoint. This
- /// means that the lattice may potentially have conflicting value states, and
- /// only the most conservative value should be relied on.
- virtual ChangeResult markPessimisticFixpoint() = 0;
-
/// When the lattice gets updated, propagate an update to users of the value
/// using its use-def chain to subscribed analyses.
void onUpdate(DataFlowSolver *solver) const override;
@@ -76,23 +74,23 @@ class AbstractSparseLattice : public AnalysisState {
template <typename ValueT>
class Lattice : public AbstractSparseLattice {
public:
- /// Construct a lattice with a known value.
- explicit Lattice(Value value)
- : AbstractSparseLattice(value),
- knownValue(ValueT::getPessimisticValueState(value)) {}
+ using AbstractSparseLattice::AbstractSparseLattice;
+
+ /// Return the program point this lattice is located at.
+ Value getPoint() const { return point.get<Value>(); }
/// Return the value held by this lattice. This requires that the value is
/// initialized.
ValueT &getValue() {
assert(!isUninitialized() && "expected known lattice element");
- return *optimisticValue;
+ return *value;
}
const ValueT &getValue() const {
return const_cast<Lattice<ValueT> *>(this)->getValue();
}
/// Returns true if the value of this lattice hasn't yet been initialized.
- bool isUninitialized() const override { return !optimisticValue.has_value(); }
+ bool isUninitialized() const override { return !value.has_value(); }
/// Join the information contained in the 'rhs' lattice into this
/// lattice. Returns if the state of the current lattice changed.
@@ -113,56 +111,37 @@ class Lattice : public AbstractSparseLattice {
ChangeResult join(const ValueT &rhs) {
// If the current lattice is uninitialized, copy the rhs value.
if (isUninitialized()) {
- optimisticValue = rhs;
+ value = rhs;
return ChangeResult::Change;
}
// Otherwise, join rhs with the current optimistic value.
- ValueT newValue = ValueT::join(*optimisticValue, rhs);
- assert(ValueT::join(newValue, *optimisticValue) == newValue &&
+ ValueT newValue = ValueT::join(*value, rhs);
+ assert(ValueT::join(newValue, *value) == newValue &&
"expected `join` to be monotonic");
assert(ValueT::join(newValue, rhs) == newValue &&
"expected `join` to be monotonic");
// Update the current optimistic value if something changed.
- if (newValue == optimisticValue)
+ if (newValue == value)
return ChangeResult::NoChange;
- optimisticValue = newValue;
- return ChangeResult::Change;
- }
-
- /// Mark the lattice element as having reached a pessimistic fixpoint. This
- /// means that the lattice may potentially have conflicting value states,
- /// and only the conservatively known value state should be relied on.
- ChangeResult markPessimisticFixpoint() override {
- if (optimisticValue == knownValue)
- return ChangeResult::NoChange;
-
- // For this fixed point, we take whatever we knew to be true and set that
- // to our optimistic value.
- optimisticValue = knownValue;
+ value = newValue;
return ChangeResult::Change;
}
/// Print the lattice element.
void print(raw_ostream &os) const override {
- os << "[";
- knownValue.print(os);
- os << ", ";
- if (optimisticValue)
- optimisticValue->print(os);
+ if (value)
+ value->print(os);
else
os << "<NULL>";
- os << "]";
}
private:
- /// The value that is conservatively known to be true.
- ValueT knownValue;
/// The currently computed value that is optimistically assumed to be true,
/// or None if the lattice element is uninitialized.
- Optional<ValueT> optimisticValue;
+ Optional<ValueT> value;
};
//===----------------------------------------------------------------------===//
@@ -213,9 +192,9 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
Value value);
- /// Mark the given lattice elements as having reached their pessimistic
- /// fixpoints and propagate an update if any changed.
- void markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice *> lattices);
+ /// Set the given lattice element(s) at control flow entry point(s).
+ virtual void setToEntryState(AbstractSparseLattice *lattice) = 0;
+ void setAllToEntryStates(ArrayRef<AbstractSparseLattice *> lattices);
/// Join the lattice element and propagate and update if it changed.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
@@ -278,8 +257,8 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
const RegionSuccessor &successor,
ArrayRef<StateT *> argLattices,
unsigned firstIndex) {
- markAllPessimisticFixpoint(argLattices.take_front(firstIndex));
- markAllPessimisticFixpoint(argLattices.drop_front(
+ setAllToEntryStates(argLattices.take_front(firstIndex));
+ setAllToEntryStates(argLattices.drop_front(
firstIndex + successor.getSuccessorInputs().size()));
}
@@ -296,10 +275,10 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value));
}
- /// Mark the lattice elements of a range of values as having reached their
- /// pessimistic fixpoint.
- void markAllPessimisticFixpoint(ArrayRef<StateT *> lattices) {
- AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
+ /// Set the given lattice element(s) at control flow entry point(s).
+ virtual void setToEntryState(StateT *lattice) = 0;
+ void setAllToEntryStates(ArrayRef<StateT *> lattices) {
+ AbstractSparseDataFlowAnalysis::setAllToEntryStates(
{reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
lattices.size()});
}
@@ -327,6 +306,9 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
argLattices.size()},
firstIndex);
}
+ void setToEntryState(AbstractSparseLattice *lattice) override {
+ return setToEntryState(reinterpret_cast<StateT *>(lattice));
+ }
};
} // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 37603e5eacef8..e6805ed78cfc1 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -288,6 +288,9 @@ class AnalysisState {
/// Create the analysis state at the given program point.
AnalysisState(ProgramPoint point) : point(point) {}
+ /// Returns the program point this state is located at.
+ ProgramPoint getPoint() const { return point; }
+
/// Returns true if the analysis state is uninitialized.
virtual bool isUninitialized() const = 0;
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 19d0b7d0f5f70..839c55d054d3c 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -39,7 +39,7 @@ void SparseConstantPropagation::visitOperation(
// folds as the desire here is for simulated execution, and not general
// folding.
if (op->getNumRegions()) {
- markAllPessimisticFixpoint(results);
+ setAllToEntryStates(results);
return;
}
@@ -59,7 +59,7 @@ void SparseConstantPropagation::visitOperation(
SmallVector<OpFoldResult, 8> foldResults;
foldResults.reserve(op->getNumResults());
if (failed(op->fold(constantOperands, foldResults))) {
- markAllPessimisticFixpoint(results);
+ setAllToEntryStates(results);
return;
}
@@ -69,7 +69,7 @@ void SparseConstantPropagation::visitOperation(
if (foldResults.empty()) {
op->setOperands(originalOperands);
op->setAttrs(originalAttrs);
- markAllPessimisticFixpoint(results);
+ setAllToEntryStates(results);
return;
}
@@ -92,3 +92,9 @@ void SparseConstantPropagation::visitOperation(
}
}
}
+
+void SparseConstantPropagation::setToEntryState(
+ Lattice<ConstantValue> *lattice) {
+ propagateIfChanged(lattice,
+ lattice->join(ConstantValue::getUnknownConstant()));
+}
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index 659ae3c38a942..55e1cb1ff95c3 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -62,7 +62,7 @@ void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) {
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
if (!predecessors->allPredecessorsKnown())
- return reset(after);
+ return setToEntryState(after);
for (Operation *predecessor : predecessors->getKnownPredecessors())
join(after, *getLatticeFor(op, predecessor));
return;
@@ -100,7 +100,7 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown())
- return reset(after);
+ return setToEntryState(after);
for (Operation *callsite : callsites->getKnownPredecessors()) {
// Get the dense lattice before the callsite.
if (Operation *prev = callsite->getPrevNode())
@@ -116,7 +116,7 @@ void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) {
return visitRegionBranchOperation(block, branch, after);
// Otherwise, we can't reason about the data-flow.
- return reset(after);
+ return setToEntryState(after);
}
// Join the state with the state after the block's predecessors.
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index 0147da813dfb7..316097d052938 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -23,7 +23,7 @@
using namespace mlir;
using namespace mlir::dataflow;
-IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
+IntegerValueRange IntegerValueRange::getMaxRange(Value value) {
unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
APInt umin = APInt::getMinValue(width);
APInt umax = APInt::getMaxValue(width);
@@ -41,7 +41,8 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
auto value = point.get<Value>();
auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
if (!constant)
- return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
+ return solver->propagateIfChanged(
+ cv, cv->join(ConstantValue::getUnknownConstant()));
Dialect *dialect;
if (auto *parent = value.getDefiningOp())
@@ -60,11 +61,13 @@ void IntegerRangeAnalysis::visitOperation(
// integer results
bool hasIntegerResult = false;
for (auto it : llvm::zip(results, op->getResults())) {
- if (std::get<1>(it).getType().isIntOrIndex()) {
+ Value value = std::get<1>(it);
+ if (value.getType().isIntOrIndex()) {
hasIntegerResult = true;
} else {
- propagateIfChanged(std::get<0>(it),
- std::get<0>(it)->markPessimisticFixpoint());
+ IntegerValueRangeLattice *lattice = std::get<0>(it);
+ propagateIfChanged(lattice,
+ lattice->join(IntegerValueRange::getMaxRange(value)));
}
}
if (!hasIntegerResult)
@@ -72,7 +75,7 @@ void IntegerRangeAnalysis::visitOperation(
auto inferrable = dyn_cast<InferIntRangeInterface>(op);
if (!inferrable)
- return markAllPessimisticFixpoint(results);
+ return setAllToEntryStates(results);
LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
SmallVector<ConstantIntRanges> argRanges(
@@ -104,7 +107,7 @@ void IntegerRangeAnalysis::visitOperation(
if (isYieldedResult && oldRange.has_value() &&
!(lattice->getValue() == *oldRange)) {
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
- changed |= lattice->markPessimisticFixpoint();
+ changed |= lattice->join(IntegerValueRange::getMaxRange(v));
}
propagateIfChanged(lattice, changed);
};
@@ -146,7 +149,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
});
if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
- changed |= lattice->markPessimisticFixpoint();
+ changed |= lattice->join(IntegerValueRange::getMaxRange(v));
}
propagateIfChanged(lattice, changed);
};
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index f85d8dfb279c6..b565b310bf9f7 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -41,7 +41,7 @@ LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
if (region.empty())
continue;
for (Value argument : region.front().getArguments())
- markAllPessimisticFixpoint(getLatticeElement(argument));
+ setAllToEntryStates(getLatticeElement(argument));
}
return initializeRecursively(top);
@@ -104,7 +104,7 @@ void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
if (!predecessors->allPredecessorsKnown())
- return markAllPessimisticFixpoint(resultLattices);
+ return setAllToEntryStates(resultLattices);
for (Operation *predecessor : predecessors->getKnownPredecessors())
for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
@@ -154,7 +154,7 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
if (!callsites->allPredecessorsKnown())
- return markAllPessimisticFixpoint(argLattices);
+ return setAllToEntryStates(argLattices);
for (Operation *callsite : callsites->getKnownPredecessors()) {
auto call = cast<CallOpInterface>(callsite);
for (auto it : llvm::zip(call.getArgOperands(), argLattices))
@@ -197,13 +197,13 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
if (Value operand = operands[it.index()]) {
join(it.value(), *getLatticeElementFor(block, operand));
} else {
- // Conservatively mark internally produced arguments as having reached
- // their pessimistic fixpoint.
- markAllPessimisticFixpoint(it.value());
+ // Conservatively consider internally produced arguments as entry
+ // points.
+ setAllToEntryStates(it.value());
}
}
} else {
- return markAllPessimisticFixpoint(argLattices);
+ return setAllToEntryStates(argLattices);
}
}
}
@@ -231,7 +231,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
if (!operands) {
// We can't reason about the data-flow.
- return markAllPessimisticFixpoint(lattices);
+ return setAllToEntryStates(lattices);
}
ValueRange inputs = predecessors->getSuccessorInputs(op);
@@ -273,10 +273,10 @@ AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
return state;
}
-void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
+void AbstractSparseDataFlowAnalysis::setAllToEntryStates(
ArrayRef<AbstractSparseLattice *> lattices) {
for (AbstractSparseLattice *lattice : lattices)
- propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
+ setToEntryState(lattice);
}
void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index 27e994cce3b64..90973af9c2cf5 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -84,19 +84,18 @@ struct ConstantAnalysis : public DataFlowAnalysis {
constant, constant->join(ConstantValue(value, op->getDialect())));
return success();
}
- markAllPessimisticFixpoint(op->getResults());
+ setAllToUnknownConstants(op->getResults());
for (Region &region : op->getRegions())
- markAllPessimisticFixpoint(region.getArguments());
+ setAllToUnknownConstants(region.getArguments());
return success();
}
- /// Mark the constant values of all given values as having reached a
- /// pessimistic fixpoint.
- void markAllPessimisticFixpoint(ValueRange values) {
+ /// Set all given values as not constants.
+ void setAllToUnknownConstants(ValueRange values) {
for (Value value : values) {
- auto *constantValue = getOrCreate<Lattice<ConstantValue>>(value);
- propagateIfChanged(constantValue,
- constantValue->markPessimisticFixpoint());
+ auto *constant = getOrCreate<Lattice<ConstantValue>>(value);
+ propagateIfChanged(constant,
+ constant->join(ConstantValue::getUnknownConstant()));
}
}
};
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
index 0fc80a6f70a37..c9eec64e75cdc 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp
@@ -20,14 +20,8 @@ namespace {
/// This lattice represents a single underlying value for an SSA value.
class UnderlyingValue {
public:
- /// The pessimistic underlying value of a value is itself.
- static UnderlyingValue getPessimisticValueState(Value value) {
- return {value};
- }
-
/// Create an underlying value state with a known underlying value.
- UnderlyingValue(Value underlyingValue = {})
- : underlyingValue(underlyingValue) {}
+ UnderlyingValue(Value underlyingValue) : underlyingValue(underlyingValue) {}
/// Returns the underlying value.
Value getUnderlyingValue() const { return underlyingValue; }
@@ -36,7 +30,7 @@ class UnderlyingValue {
/// go to the pessimistic value.
static UnderlyingValue join(const UnderlyingValue &lhs,
const UnderlyingValue &rhs) {
- return lhs.underlyingValue == rhs.underlyingValue ? lhs : UnderlyingValue();
+ return lhs.underlyingValue == rhs.underlyingValue ? lhs : Value();
}
/// Compare underlying values.
@@ -61,9 +55,8 @@ class LastModification : public AbstractDenseLattice {
/// The lattice is always initialized.
bool isUninitialized() const override { return false; }
- /// Mark the lattice as having reached its pessimistic fixpoint. That is, the
- /// last modifications of all memory resources are unknown.
- ChangeResult reset() override {
+ /// Clear all modifications.
+ ChangeResult reset() {
if (lastMods.empty())
return ChangeResult::NoChange;
lastMods.clear();
@@ -131,6 +124,12 @@ class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
/// resource, then its reaching definition is set to the written value.
void visitOperation(Operation *op, const LastModification &before,
LastModification *after) override;
+
+ /// At an entry point, the last modifications of all memory resources are
+ /// unknown.
+ void setToEntryState(LastModification *lattice) override {
+ propagateIfChanged(lattice, lattice->reset());
+ }
};
/// Define the lattice class explicitly to provide a type ID.
@@ -152,7 +151,13 @@ class UnderlyingValueAnalysis
void visitOperation(Operation *op,
ArrayRef<const UnderlyingValueLattice *> operands,
ArrayRef<UnderlyingValueLattice *> results) override {
- markAllPessimisticFixpoint(results);
+ setAllToEntryStates(results);
+ }
+
+ /// At an entry point, the underlying value of a value is itself.
+ void setToEntryState(UnderlyingValueLattice *lattice) override {
+ propagateIfChanged(lattice,
+ lattice->join(UnderlyingValue{lattice->getPoint()}));
}
};
} // end anonymous namespace
@@ -181,7 +186,7 @@ void LastModifiedAnalysis::visitOperation(Operation *op,
// If we can't reason about the memory effects, then conservatively assume we
// can't deduce anything about the last modifications.
if (!memory)
- return reset(after);
+ return setToEntryState(after);
SmallVector<MemoryEffects::EffectInstance> effects;
memory.getEffects(effects);
@@ -193,7 +198,7 @@ void LastModifiedAnalysis::visitOperation(Operation *op,
// If we see an effect on anything other than a value, assume we can't
// deduce anything about the last modifications.
if (!value)
- return reset(after);
+ return setToEntryState(after);
value = getMostUnderlyingValue(value, [&](Value value) {
return getOrCreateFor<UnderlyingValueLattice>(op, value);