[llvm-branch-commits] [mlir] a14f859 - [mlir] Swap integer range inference to the new framework
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jun 29 10:23:50 PDT 2022
Author: Mogball
Date: 2022-06-27T13:18:24-07:00
New Revision: a14f85985c9f009c4d051f82d92842589ebb2b1b
URL: https://github.com/llvm/llvm-project/commit/a14f85985c9f009c4d051f82d92842589ebb2b1b
DIFF: https://github.com/llvm/llvm-project/commit/a14f85985c9f009c4d051f82d92842589ebb2b1b.diff
LOG: [mlir] Swap integer range inference to the new framework
Integer range inference has been swapped to the new framework. The integer value range lattice automatically updates the corresponding constant value on update.
Added:
Modified:
mlir/include/mlir/Analysis/DataFlowFramework.h
mlir/include/mlir/Analysis/IntRangeAnalysis.h
mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
mlir/include/mlir/Interfaces/InferIntRangeInterface.td
mlir/lib/Analysis/DataFlowFramework.cpp
mlir/lib/Analysis/IntRangeAnalysis.cpp
mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
mlir/lib/Transforms/SCCP.cpp
mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp
mlir/test/lib/Transforms/TestIntRangeInference.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 19d8fc0c3e19b..2992e05f14ddf 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -226,7 +226,6 @@ class DataFlowSolver {
/// Push a work item onto the worklist.
void enqueue(WorkItem item) { worklist.push(std::move(item)); }
-protected:
/// Get the state associated with the given program point. If it does not
/// exist, create an uninitialized state.
template <typename StateT, typename PointT>
diff --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
index b2b604359b48b..76b6b27c05187 100644
--- a/mlir/include/mlir/Analysis/IntRangeAnalysis.h
+++ b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
@@ -15,27 +15,81 @@
#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
#define MLIR_ANALYSIS_INTRANGEANALYSIS_H
+#include "mlir/Analysis/SparseDataFlowAnalysis.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
namespace mlir {
-namespace detail {
-class IntRangeAnalysisImpl;
-} // end namespace detail
-class IntRangeAnalysis {
+/// This lattice value represents the integer range of an SSA value.
+class IntegerValueRange {
public:
- /// Analyze all operations rooted under (but not including)
- /// `topLevelOperation`.
- IntRangeAnalysis(Operation *topLevelOperation);
- IntRangeAnalysis(IntRangeAnalysis &&other);
- ~IntRangeAnalysis();
+ /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
+ /// range that is used to mark the value as unable to be analyzed further,
+ /// where `t` is the type of `value`.
+ static IntegerValueRange getPessimisticValueState(Value value);
- /// Get inferred range for value `v` if one exists.
- Optional<ConstantIntRanges> getResult(Value v);
+ /// Create an integer value range lattice value.
+ IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
+
+ /// Get the known integer value range.
+ const ConstantIntRanges &getValue() const { return value; }
+
+ /// Compare two ranges.
+ bool operator==(const IntegerValueRange &rhs) const {
+ return value == rhs.value;
+ }
+
+ /// Take the union of two ranges.
+ static IntegerValueRange join(const IntegerValueRange &lhs,
+ const IntegerValueRange &rhs) {
+ return lhs.value.rangeUnion(rhs.value);
+ }
+
+ /// Print the integer value range.
+ void print(raw_ostream &os) const { os << value; }
private:
- std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
+ /// The known integer value range.
+ ConstantIntRanges value;
+};
+
+/// This lattice element represents the integer value range of an SSA value.
+/// When this lattice is updated, it automatically updates the constant value
+/// of the SSA value (if the range can be narrowed to one).
+class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
+public:
+ using Lattice::Lattice;
+
+ /// If the range can be narrowed to an integer constant, update the constant
+ /// value of the SSA value.
+ void onUpdate(DataFlowSolver *solver) const override;
};
+
+/// Integer range analysis determines the integer value range of SSA values
+/// using operations that define `InferIntRangeInterface` and also sets the
+/// range of iteration indices of loops with known bounds.
+class IntegerRangeAnalysis
+ : public SparseDataFlowAnalysis<IntegerValueRangeLattice> {
+public:
+ using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
+
+ /// Visit an operation. Invoke the transfer function on each operation that
+ /// implements `InferIntRangeInterface`.
+ void visitOperation(Operation *op,
+ ArrayRef<const IntegerValueRangeLattice *> operands,
+ ArrayRef<IntegerValueRangeLattice *> results) override;
+
+ /// Visit block arguments or operation results of an operation with region
+ /// control-flow for which values are not defined by region control-flow. This
+ /// function calls `InferIntRangeInterface` to provide values for block
+ /// arguments or tries to reduce the range on loop induction variables with
+ /// known bounds.
+ void
+ visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
+ ArrayRef<IntegerValueRangeLattice *> argLattices,
+ unsigned firstIndex) override;
+};
+
} // end namespace mlir
#endif
diff --git a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
index e18cd6a4d835b..412a0fea767a3 100644
--- a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
+++ b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
@@ -455,6 +455,14 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
+ /// Given an operation with region control-flow, the lattices of the operands,
+ /// and a region successor, compute the lattice values for block arguments
+ /// that are not accounted for by the branching control flow (ex. the bounds
+ /// of loops).
+ virtual void visitNonControlFlowArgumentsImpl(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
+
/// Get the lattice element of a value.
virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
@@ -515,6 +523,21 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;
+ /// Given an operation with possible region control-flow, the lattices of the
+ /// operands, and a region successor, compute the lattice values for block
+ /// arguments that are not accounted for by the branching control flow (ex.
+ /// the bounds of loops). By default, this method marks all such lattice
+ /// elements as having reached a pessimistic fixpoint. `firstIndex` is the
+ /// index of the first element of `argLattices` that is set by control-flow.
+ virtual void visitNonControlFlowArguments(Operation *op,
+ const RegionSuccessor &successor,
+ ArrayRef<StateT *> argLattices,
+ unsigned firstIndex) {
+ markAllPessimisticFixpoint(argLattices.take_front(firstIndex));
+ markAllPessimisticFixpoint(argLattices.drop_front(
+ firstIndex + successor.getSuccessorInputs().size()));
+ }
+
protected:
/// Get the lattice element for a value.
StateT *getLatticeElement(Value value) override {
@@ -549,6 +572,16 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
+ void visitNonControlFlowArgumentsImpl(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<AbstractSparseLattice *> argLattices,
+ unsigned firstIndex) override {
+ visitNonControlFlowArguments(
+ op, successor,
+ {reinterpret_cast<StateT *const *>(argLattices.begin()),
+ argLattices.size()},
+ firstIndex);
+ }
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index 57f8d693b7916..abe6df1543625 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -30,7 +30,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
since the dataflow analysis handles those case), the method should call
`setValueRange` with that `Value` as an argument. When `setValueRange`
is not called for some value, it will recieve a default value of the mimimum
- and maximum values forits type (the unbounded range).
+ and maximum values for its type (the unbounded range).
When called on an op that also implements the RegionBranchOpInterface
or BranchOpInterface, this method should not attempt to infer the values
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index be18432468d4f..18d9ba1bd5d60 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -87,19 +87,6 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
return failure();
}
- // "Nudge" the state of the analysis by forcefully initializing states that
- // are still uninitialized. All uninitialized states in the graph can be
- // initialized in any order because the analysis reached fixpoint, meaning
- // that there are no work items that would have further nudged the analysis.
- for (AnalysisState &state :
- llvm::make_pointee_range(llvm::make_second_range(analysisStates))) {
- if (!state.isUninitialized())
- continue;
- DATAFLOW_DEBUG(llvm::dbgs() << "Default initializing " << state.debugName
- << " of " << state.point << "\n");
- propagateIfChanged(&state, state.defaultInitialize());
- }
-
// Iterate until all states are in some initialized state and the worklist
// is exhausted.
} while (!worklist.empty());
diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp
index f887d68d12ec2..b876b372e286f 100644
--- a/mlir/lib/Analysis/IntRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp
@@ -13,7 +13,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/IntRangeAnalysis.h"
-#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "llvm/Support/Debug.h"
@@ -22,244 +21,120 @@
using namespace mlir;
-namespace {
-/// A wrapper around ConstantIntRanges that provides the lattice functions
-/// expected by dataflow analysis.
-struct IntRangeLattice {
- IntRangeLattice(const ConstantIntRanges &value) : value(value){};
- IntRangeLattice(ConstantIntRanges &&value) : value(value){};
-
- bool operator==(const IntRangeLattice &other) const {
- return value == other.value;
- }
-
- /// wrapper around rangeUnion()
- static IntRangeLattice join(const IntRangeLattice &a,
- const IntRangeLattice &b) {
- return a.value.rangeUnion(b.value);
- }
-
- /// Creates a range with bitwidth 0 to represent that we don't know if the
- /// value being marked overdefined is even an integer.
- static IntRangeLattice getPessimisticValueState(MLIRContext *context) {
- APInt noIntValue = APInt::getZeroWidth();
- return ConstantIntRanges(noIntValue, noIntValue, noIntValue, noIntValue);
- }
-
- /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
- /// range that is used to mark the value v as unable to be analyzed further,
- /// where t is the type of v.
- static IntRangeLattice getPessimisticValueState(Value v) {
- unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType());
- APInt umin = APInt::getMinValue(width);
- APInt umax = APInt::getMaxValue(width);
- APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
- APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
- return ConstantIntRanges{umin, umax, smin, smax};
- }
-
- ConstantIntRanges value;
-};
-} // end anonymous namespace
-
-namespace mlir {
-namespace detail {
-class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<IntRangeLattice> {
- using ForwardDataFlowAnalysis<IntRangeLattice>::ForwardDataFlowAnalysis;
-
-public:
- /// Define bounds on the results or block arguments of the operation
- /// based on the bounds on the arguments given in `operands`
- ChangeResult
- visitOperation(Operation *op,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
-
- /// Skip regions of branch ops when we can statically infer constant
- /// values for operands to the branch op and said op tells us it's safe to do
- /// so.
- LogicalResult
- getSuccessorsForOperands(BranchOpInterface branch,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands,
- SmallVectorImpl<Block *> &successors) final;
-
- /// Skip regions of branch or loop ops when we can statically infer constant
- /// values for operands to the branch op and said op tells us it's safe to do
- /// so.
- void
- getSuccessorsForOperands(RegionBranchOpInterface branch,
- Optional<unsigned> sourceIndex,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands,
- SmallVectorImpl<RegionSuccessor> &successors) final;
-
- /// Call the InferIntRangeInterface implementation for region-using ops
- /// that implement it, and infer the bounds of loop induction variables
- /// for ops that implement LoopLikeOPInterface.
- ChangeResult visitNonControlFlowArguments(
- Operation *op, const RegionSuccessor &region,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
-};
-} // end namespace detail
-} // end namespace mlir
+IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
+ unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
+ APInt umin = APInt::getMinValue(width);
+ APInt umax = APInt::getMaxValue(width);
+ APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
+ APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
+ return {{umin, umax, smin, smax}};
+}
-/// Given the results of getConstant{Lower,Upper}Bound()
-/// or getConstantStep() on a LoopLikeInterface return the lower/upper bound for
-/// that result if possible.
-static APInt getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
- Type boundType,
- detail::IntRangeAnalysisImpl &analysis,
- bool getUpper) {
- unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
- if (loopBound) {
- if (loopBound->is<Attribute>()) {
- if (auto bound =
- loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
- return bound.getValue();
- } else if (loopBound->is<Value>()) {
- LatticeElement<IntRangeLattice> *lattice =
- analysis.lookupLatticeElement(loopBound->get<Value>());
- if (lattice != nullptr)
- return getUpper ? lattice->getValue().value.smax()
- : lattice->getValue().value.smin();
- }
- }
- return getUpper ? APInt::getSignedMaxValue(width)
- : APInt::getSignedMinValue(width);
+void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
+ Lattice::onUpdate(solver);
+
+ // If the integer range can be narrowed to a constant, update the constant
+ // value of the SSA value.
+ Optional<APInt> constant = getValue().getValue().getConstantValue();
+ auto value = point.get<Value>();
+ auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
+ if (!constant)
+ return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
+
+ Dialect *dialect;
+ if (auto *parent = value.getDefiningOp())
+ dialect = parent->getDialect();
+ else
+ dialect = value.getParentBlock()->getParentOp()->getDialect();
+ solver->propagateIfChanged(
+ cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
+ dialect)));
}
-ChangeResult detail::IntRangeAnalysisImpl::visitOperation(
- Operation *op, ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
- ChangeResult result = ChangeResult::NoChange;
+void IntegerRangeAnalysis::visitOperation(
+ Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
+ ArrayRef<IntegerValueRangeLattice *> results) {
// Ignore non-integer outputs - return early if the op has no scalar
// integer results
bool hasIntegerResult = false;
- for (Value v : op->getResults()) {
- if (v.getType().isIntOrIndex())
+ for (auto it : llvm::zip(results, op->getResults())) {
+ if (std::get<1>(it).getType().isIntOrIndex()) {
hasIntegerResult = true;
- else
- result |= markAllPessimisticFixpoint(v);
- }
- if (!hasIntegerResult)
- return result;
-
- if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
- LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
- LLVM_DEBUG(inferrable->print(llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
- SmallVector<ConstantIntRanges> argRanges(
- llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
- return val->getValue().value;
- }));
-
- auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
- LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
- LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
- Optional<IntRangeLattice> oldRange;
- if (!lattice.isUninitialized())
- oldRange = lattice.getValue();
- result |= lattice.join(IntRangeLattice(attrs));
-
- // Catch loop results with loop variant bounds and conservatively make
- // them [-inf, inf] so we don't circle around infinitely often (because
- // the dataflow analysis in MLIR doesn't attempt to work out trip counts
- // and often can't).
- bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
- return op->hasTrait<OpTrait::IsTerminator>();
- });
- if (isYieldedResult && oldRange && !(lattice.getValue() == *oldRange)) {
- LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
- result |= lattice.markPessimisticFixpoint();
- }
- };
-
- inferrable.inferResultRanges(argRanges, joinCallback);
- for (Value opResult : op->getResults()) {
- LatticeElement<IntRangeLattice> &lattice = getLatticeElement(opResult);
- // setResultRange() not called, make pessimistic.
- if (lattice.isUninitialized())
- result |= lattice.markPessimisticFixpoint();
- }
- } else if (op->getNumRegions() == 0) {
- // No regions + no result inference method -> unbounded results (ex. memory
- // ops)
- result |= markAllPessimisticFixpoint(op->getResults());
- }
- return result;
-}
-
-LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
- BranchOpInterface branch,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands,
- SmallVectorImpl<Block *> &successors) {
- auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
- Optional<APInt> maybeConstValue =
- enumPair.value()->getValue().value.getConstantValue();
-
- if (maybeConstValue) {
- return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
- *maybeConstValue);
+ } else {
+ propagateIfChanged(std::get<0>(it),
+ std::get<0>(it)->markPessimisticFixpoint());
}
- return {};
- };
- SmallVector<Attribute> inferredConsts(
- llvm::map_range(llvm::enumerate(operands), toConstantAttr));
- if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
- successors.push_back(singleSucc);
- return success();
}
- return failure();
-}
-
-void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
- RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands,
- SmallVectorImpl<RegionSuccessor> &successors) {
- // Get a type with which to construct a constant.
- auto getOperandType = [branch, sourceIndex](unsigned index) {
- // The types of all return-like operations are the same.
- if (!sourceIndex)
- return branch->getOperand(index).getType();
-
- for (Block &block : branch->getRegion(*sourceIndex)) {
- Operation *terminator = block.getTerminator();
- if (getRegionBranchSuccessorOperands(terminator, *sourceIndex))
- return terminator->getOperand(index).getType();
+ if (!hasIntegerResult)
+ return;
+
+ auto inferrable = dyn_cast<InferIntRangeInterface>(op);
+ if (!inferrable)
+ return markAllPessimisticFixpoint(results);
+
+ LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
+ SmallVector<ConstantIntRanges> argRanges(
+ llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
+ return val->getValue().getValue();
+ }));
+
+ auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ auto result = v.dyn_cast<OpResult>();
+ if (!result)
+ return;
+ assert(llvm::find(op->getResults(), result) != op->result_end());
+
+ LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+ IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
+ Optional<IntegerValueRange> oldRange;
+ if (!lattice->isUninitialized())
+ oldRange = lattice->getValue();
+
+ ChangeResult changed = lattice->join(attrs);
+
+ // Catch loop results with loop variant bounds and conservatively make
+ // them [-inf, inf] so we don't circle around infinitely often (because
+ // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+ // and often can't).
+ bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+ return op->hasTrait<OpTrait::IsTerminator>();
+ });
+ if (isYieldedResult && oldRange.hasValue() &&
+ !(lattice->getValue() == *oldRange)) {
+ LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ changed |= lattice->markPessimisticFixpoint();
}
- return Type();
+ propagateIfChanged(lattice, changed);
};
- auto toConstantAttr = [&getOperandType](auto enumPair) -> Attribute {
- if (Optional<APInt> maybeConstValue =
- enumPair.value()->getValue().value.getConstantValue()) {
- return IntegerAttr::get(getOperandType(enumPair.index()),
- *maybeConstValue);
- }
- return {};
- };
- SmallVector<Attribute> inferredConsts(
- llvm::map_range(llvm::enumerate(operands), toConstantAttr));
- branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
+ inferrable.inferResultRanges(argRanges, joinCallback);
}
-ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
- Operation *op, const RegionSuccessor &region,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
+void IntegerRangeAnalysis::visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
- LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
- LLVM_DEBUG(inferrable->print(llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
SmallVector<ConstantIntRanges> argRanges(
- llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
- return val->getValue().value;
+ llvm::map_range(op->getOperands(), [&](Value value) {
+ return getLatticeElementFor(op, value)->getValue().getValue();
}));
- ChangeResult result = ChangeResult::NoChange;
auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ auto arg = v.dyn_cast<BlockArgument>();
+ if (!arg)
+ return;
+ if (llvm::find(successor.getSuccessor()->getArguments(), arg) ==
+ successor.getSuccessor()->args_end())
+ return;
+
LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
- LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
- Optional<IntRangeLattice> oldRange;
- if (!lattice.isUninitialized())
- oldRange = lattice.getValue();
- result |= lattice.join(IntRangeLattice(attrs));
+ IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
+ Optional<IntegerValueRange> oldRange;
+ if (!lattice->isUninitialized())
+ oldRange = lattice->getValue();
+
+ ChangeResult changed = lattice->join(attrs);
// Catch loop results with loop variant bounds and conservatively make
// them [-inf, inf] so we don't circle around infinitely often (because
@@ -268,68 +143,75 @@ ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
return op->hasTrait<OpTrait::IsTerminator>();
});
- if (isYieldedValue && oldRange && !(lattice.getValue() == *oldRange)) {
+ if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
- result |= lattice.markPessimisticFixpoint();
+ changed |= lattice->markPessimisticFixpoint();
}
+ propagateIfChanged(lattice, changed);
};
inferrable.inferResultRanges(argRanges, joinCallback);
- for (Value regionArg : region.getSuccessor()->getArguments()) {
- LatticeElement<IntRangeLattice> &lattice = getLatticeElement(regionArg);
- // setResultRange() not called, make pessimistic.
- if (lattice.isUninitialized())
- result |= lattice.markPessimisticFixpoint();
- }
-
- return result;
+ return;
}
+ /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
+ /// on a LoopLikeInterface return the lower/upper bound for that result if
+ /// possible.
+ auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound,
+ Type boundType, bool getUpper) {
+ unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
+ if (loopBound.hasValue()) {
+ if (loopBound->is<Attribute>()) {
+ if (auto bound =
+ loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+ return bound.getValue();
+ } else if (auto value = loopBound->dyn_cast<Value>()) {
+ const IntegerValueRangeLattice *lattice =
+ getLatticeElementFor(op, value);
+ if (lattice != nullptr)
+ return getUpper ? lattice->getValue().getValue().smax()
+ : lattice->getValue().getValue().smin();
+ }
+ }
+ // Given the results of getConstant{Lower,Upper}Bound()
+ // or getConstantStep() on a LoopLikeInterface return the lower/upper
+ // bound
+ return getUpper ? APInt::getSignedMaxValue(width)
+ : APInt::getSignedMinValue(width);
+ };
+
// Infer bounds for loop arguments that have static bounds
if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
Optional<Value> iv = loop.getSingleInductionVar();
if (!iv) {
- return ForwardDataFlowAnalysis<
- IntRangeLattice>::visitNonControlFlowArguments(op, region, operands);
+ return SparseDataFlowAnalysis ::visitNonControlFlowArguments(
+ op, successor, argLattices, firstIndex);
}
Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
Optional<OpFoldResult> step = loop.getSingleStep();
- APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this,
+ APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
/*getUpper=*/false);
- APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this,
+ APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
/*getUpper=*/true);
// Assume positivity for uniscoverable steps by way of getUpper = true.
APInt stepVal =
- getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true);
+ getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
if (stepVal.isNegative()) {
std::swap(min, max);
} else {
- // Correct the upper bound by subtracting 1 so that it becomes a <= bound,
- // because loops do not generally include their upper bound.
+ // Correct the upper bound by subtracting 1 so that it becomes a <=
+ // bound, because loops do not generally include their upper bound.
max -= 1;
}
- LatticeElement<IntRangeLattice> &ivEntry = getLatticeElement(*iv);
- return ivEntry.join(ConstantIntRanges::fromSigned(min, max));
+ IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
+ auto ivRange = ConstantIntRanges::fromSigned(min, max);
+ propagateIfChanged(ivEntry, ivEntry->join(ivRange));
+ return;
}
- return ForwardDataFlowAnalysis<IntRangeLattice>::visitNonControlFlowArguments(
- op, region, operands);
-}
-
-IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) {
- impl = std::make_unique<mlir::detail::IntRangeAnalysisImpl>(
- topLevelOperation->getContext());
- impl->run(topLevelOperation);
-}
-
-IntRangeAnalysis::~IntRangeAnalysis() = default;
-IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default;
-Optional<ConstantIntRanges> IntRangeAnalysis::getResult(Value v) {
- LatticeElement<IntRangeLattice> *result = impl->lookupLatticeElement(v);
- if (result == nullptr || result->isUninitialized())
- return llvm::None;
- return result->getValue().value;
+ return SparseDataFlowAnalysis::visitNonControlFlowArguments(
+ op, successor, argLattices, firstIndex);
}
diff --git a/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
index 80c1293d22729..521a22d1d59fe 100644
--- a/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
@@ -189,10 +189,19 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
walkFn);
}
+/// Returns true if the operation terminates a block. It is insufficient to
+/// check for `OpTrait::IsTerminator` because unregistered operations can be
+/// terminators.
+static bool isTerminator(Operation *op) {
+ if (op->hasTrait<OpTrait::IsTerminator>())
+ return true;
+ return &op->getBlock()->back() == op;
+}
+
LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
// Initialize the analysis by visiting every op with control-flow semantics.
- if (op->getNumRegions() || op->getNumSuccessors() ||
- op->hasTrait<OpTrait::IsTerminator>() || isa<CallOpInterface>(op)) {
+ if (op->getNumRegions() || op->getNumSuccessors() || isTerminator(op) ||
+ isa<CallOpInterface>(op)) {
// When the liveness of the parent block changes, make sure to re-invoke the
// analysis on the op.
if (op->getBlock())
@@ -262,7 +271,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
}
}
- if (op->hasTrait<OpTrait::IsTerminator>() && !op->getNumSuccessors()) {
+ if (isTerminator(op) && !op->getNumSuccessors()) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
// Visit the exiting terminator of a region.
visitRegionTerminator(op, branch);
@@ -593,7 +602,9 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
}
// Otherwise, we can't reason about the data-flow.
- return markAllPessimisticFixpoint(argLattices);
+ return visitNonControlFlowArgumentsImpl(block->getParentOp(),
+ RegionSuccessor(block->getParent()),
+ argLattices, /*firstIndex=*/0);
}
// Iterate over the predecessors of the non-entry block.
@@ -646,7 +657,7 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
operands = branch.getSuccessorEntryOperands(successorIndex);
// Otherwise, try to deduce the operands from a region return-like op.
} else {
- assert(op->hasTrait<OpTrait::IsTerminator>() && "expected a terminator");
+ assert(isTerminator(op) && "expected a terminator");
if (isRegionReturnLike(op))
operands = getRegionBranchSuccessorOperands(op, successorIndex);
}
@@ -660,17 +671,26 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
assert(inputs.size() == operands->size() &&
"expected the same number of successor inputs as operands");
- // TODO: This was updated to be exposed upstream.
unsigned firstIndex = 0;
if (inputs.size() != lattices.size()) {
- if (inputs.empty()) {
- markAllPessimisticFixpoint(lattices);
- return;
+ if (auto *op = point.dyn_cast<Operation *>()) {
+ if (!inputs.empty())
+ firstIndex = inputs.front().cast<OpResult>().getResultNumber();
+ visitNonControlFlowArgumentsImpl(
+ branch,
+ RegionSuccessor(
+ branch->getResults().slice(firstIndex, inputs.size())),
+ lattices, firstIndex);
+ } else {
+ if (!inputs.empty())
+ firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
+ Region *region = point.get<Block *>()->getParent();
+ visitNonControlFlowArgumentsImpl(
+ branch,
+ RegionSuccessor(region, region->getArguments().slice(
+ firstIndex, inputs.size())),
+ lattices, firstIndex);
}
- firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
- markAllPessimisticFixpoint(lattices.take_front(firstIndex));
- markAllPessimisticFixpoint(
- lattices.drop_front(firstIndex + inputs.size()));
}
for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
@@ -717,7 +737,7 @@ void SparseConstantPropagation::visitOperation(
// folds as the desire here is for simulated execution, and not general
// folding.
if (op->getNumRegions())
- return;
+ return markAllPessimisticFixpoint(results);
SmallVector<Attribute, 8> constantOperands;
constantOperands.reserve(op->getNumOperands());
@@ -734,10 +754,8 @@ void SparseConstantPropagation::visitOperation(
// fails or was an in-place fold, mark the results as overdefined.
SmallVector<OpFoldResult, 8> foldResults;
foldResults.reserve(op->getNumResults());
- if (failed(op->fold(constantOperands, foldResults))) {
- markAllPessimisticFixpoint(results);
- return;
- }
+ if (failed(op->fold(constantOperands, foldResults)))
+ return markAllPessimisticFixpoint(results);
// If the folding was in-place, mark the results as overdefined and reset
// the operation. We don't allow in-place folds as the desire here is for
@@ -745,7 +763,7 @@ void SparseConstantPropagation::visitOperation(
if (foldResults.empty()) {
op->setOperands(originalOperands);
op->setAttrs(originalAttrs);
- return;
+ return markAllPessimisticFixpoint(results);
}
// Merge the fold results into the lattice for this operation.
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
index f84990d0a8c47..599c1fe592ef6 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -20,22 +20,21 @@ using namespace mlir::arith;
/// Succeeds when a value is statically non-negative in that it has a lower
/// bound on its value (if it is treated as signed) and that bound is
/// non-negative.
-static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
- Value v) {
- Optional<ConstantIntRanges> result = analysis.getResult(v);
- if (!result.hasValue())
+static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
+ auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
+ if (!result || result->isUninitialized())
return failure();
- const ConstantIntRanges &range = result.getValue();
+ const ConstantIntRanges &range = result->getValue().getValue();
return success(range.smin().isNonNegative());
}
/// Succeeds if an op can be converted to its unsigned equivalent without
/// changing its semantics. This is the case when none of its openands or
/// results can be below 0 when analyzed from a signed perspective.
-static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
+static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
Operation *op) {
- auto nonNegativePred = [&analysis](Value v) -> bool {
- return succeeded(staticallyNonNegative(analysis, v));
+ auto nonNegativePred = [&solver](Value v) -> bool {
+ return succeeded(staticallyNonNegative(solver, v));
};
return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
llvm::all_of(op->getResults(), nonNegativePred));
@@ -44,15 +43,15 @@ static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
/// Succeeds when the comparison predicate is a signed operation and all the
/// operands are non-negative, indicating that the cmpi operation `op` can have
/// its predicate changed to an unsigned equivalent.
-static LogicalResult isCmpIConvertable(IntRangeAnalysis &analysis, CmpIOp op) {
+static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
CmpIPredicate pred = op.getPredicate();
switch (pred) {
case CmpIPredicate::sle:
case CmpIPredicate::slt:
case CmpIPredicate::sge:
case CmpIPredicate::sgt:
- return success(llvm::all_of(op.getOperands(), [&analysis](Value v) -> bool {
- return succeeded(staticallyNonNegative(analysis, v));
+ return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
+ return succeeded(staticallyNonNegative(solver, v));
}));
default:
return failure();
@@ -109,19 +108,23 @@ struct ArithmeticUnsignedWhenEquivalentPass
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
- IntRangeAnalysis analysis(op);
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<IntegerRangeAnalysis>();
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
ConversionTarget target(*ctx);
target.addLegalDialect<ArithmeticDialect>();
target
.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, CeilDivUIOp, FloorDivSIOp,
RemSIOp, MinSIOp, MaxSIOp, ExtSIOp>(
- [&analysis](Operation *op) -> Optional<bool> {
- return failed(staticallyNonNegative(analysis, op));
+ [&solver](Operation *op) -> Optional<bool> {
+ return failed(staticallyNonNegative(solver, op));
});
target.addDynamicallyLegalOp<CmpIOp>(
- [&analysis](CmpIOp op) -> Optional<bool> {
- return failed(isCmpIConvertable(analysis, op));
+ [&solver](CmpIOp op) -> Optional<bool> {
+ return failed(isCmpIConvertable(solver, op));
});
RewritePatternSet patterns(ctx);
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 548950aaed81a..078a82b7de5f2 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -36,7 +36,7 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver,
OpBuilder &builder,
OperationFolder &folder, Value value) {
auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
- if (!lattice)
+ if (!lattice || lattice->isUninitialized())
return failure();
const ConstantValue &latticeValue = lattice->getValue();
if (!latticeValue.getConstantValue())
diff --git a/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp
index a7c33526c80de..e26eda9c831a4 100644
--- a/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp
@@ -66,9 +66,8 @@ struct ConstantAnalysis : public DataFlowAnalysis {
LogicalResult initialize(Operation *top) override {
WalkResult result = top->walk([&](Operation *op) {
- if (op->hasTrait<OpTrait::ConstantLike>())
- if (failed(visit(op)))
- return WalkResult::interrupt();
+ if (failed(visit(op)))
+ return WalkResult::interrupt();
return WalkResult::advance();
});
return success(!result.wasInterrupted());
@@ -81,13 +80,27 @@ struct ConstantAnalysis : public DataFlowAnalysis {
auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
propagateIfChanged(
constant, constant->join(ConstantValue(value, op->getDialect())));
+ return success();
}
+ markAllPessimisticFixpoint(op->getResults());
+ for (Region ®ion : op->getRegions())
+ markAllPessimisticFixpoint(region.getArguments());
return success();
}
+
+ /// Mark the constant values of all given values as having reached a
+ /// pessimistic fixpoint.
+ void markAllPessimisticFixpoint(ValueRange values) {
+ for (Value value : values) {
+ auto *constantValue = getOrCreate<Lattice<ConstantValue>>(value);
+ propagateIfChanged(constantValue,
+ constantValue->markPessimisticFixpoint());
+ }
+ }
};
-/// This is a simple pass that runs dead code analysis with no constant value
-/// provider. It marks everything as live.
+/// This is a simple pass that runs dead code analysis with a constant value
+/// provider that only understands constant operations.
struct TestDeadCodeAnalysisPass
: public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
index 1bd2a24d3ce6c..1f75007fcf9a3 100644
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
@@ -19,13 +19,14 @@
using namespace mlir;
/// Patterned after SCCP
-static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
- OpBuilder &b, OperationFolder &folder,
- Value value) {
- Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
- if (!maybeInferredRange)
+static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
+ OperationFolder &folder, Value value) {
+ auto *maybeInferredRange =
+ solver.lookupState<IntegerValueRangeLattice>(value);
+ if (!maybeInferredRange || maybeInferredRange->isUninitialized())
return failure();
- const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
+ const ConstantIntRanges &inferredRange =
+ maybeInferredRange->getValue().getValue();
Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
if (!maybeConstValue.hasValue())
return failure();
@@ -44,7 +45,7 @@ static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
return success();
}
-static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
+static void rewrite(DataFlowSolver &solver, MLIRContext *context,
MutableArrayRef<Region> initialRegions) {
SmallVector<Block *> worklist;
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
@@ -67,7 +68,7 @@ static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
bool replacedAll = op.getNumResults() != 0;
for (Value res : op.getResults())
replacedAll &=
- succeeded(replaceWithConstant(analysis, builder, folder, res));
+ succeeded(replaceWithConstant(solver, builder, folder, res));
// If all of the results of the operation were replaced, try to erase
// the operation completely.
@@ -84,7 +85,7 @@ static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
// Replace any block arguments with constants.
builder.setInsertionPointToStart(block);
for (BlockArgument arg : block->getArguments())
- (void)replaceWithConstant(analysis, builder, folder, arg);
+ (void)replaceWithConstant(solver, builder, folder, arg);
}
}
@@ -100,8 +101,12 @@ struct TestIntRangeInference
void runOnOperation() override {
Operation *op = getOperation();
- IntRangeAnalysis analysis(op);
- rewrite(analysis, op->getContext(), op->getRegions());
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<IntegerRangeAnalysis>();
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
+ rewrite(solver, op->getContext(), op->getRegions());
}
};
} // end anonymous namespace
More information about the llvm-branch-commits
mailing list