[Mlir-commits] [mlir] ab70197 - [mlir] Swap integer range inference to the new framework
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 7 20:28:18 PDT 2022
Author: Mogball
Date: 2022-07-07T20:28:13-07:00
New Revision: ab701975e7f3b63bb474afbdeb8c474950d41074
URL: https://github.com/llvm/llvm-project/commit/ab701975e7f3b63bb474afbdeb8c474950d41074
DIFF: https://github.com/llvm/llvm-project/commit/ab701975e7f3b63bb474afbdeb8c474950d41074.diff
LOG: [mlir] Swap integer range inference to the new framework
Integer range inference has been switched to the new framework. The integer value range lattice automatically updates the corresponding constant value whenever the lattice itself is updated.
Depends on D127173
Reviewed By: krzysz00, rriddle
Differential Revision: https://reviews.llvm.org/D128866
Added:
mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
Modified:
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/include/mlir/Analysis/DataFlowFramework.h
mlir/include/mlir/Interfaces/InferIntRangeInterface.td
mlir/lib/Analysis/CMakeLists.txt
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Analysis/DataFlowFramework.cpp
mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
mlir/lib/Transforms/SCCP.cpp
mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
mlir/test/lib/Transforms/TestIntRangeInference.cpp
Removed:
mlir/include/mlir/Analysis/IntRangeAnalysis.h
mlir/lib/Analysis/IntRangeAnalysis.cpp
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
new file mode 100644
index 0000000000000..3cd007ab478ba
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h
@@ -0,0 +1,97 @@
+//===- IntegerRangeAnalysis.h - Integer range analysis ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the dataflow analysis class for integer range inference
+// so that it can be used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOW_INTEGERRANGEANALYSIS_H
+#define MLIR_ANALYSIS_DATAFLOW_INTEGERRANGEANALYSIS_H
+
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+
+namespace mlir {
+namespace dataflow {
+
+/// This lattice value represents the integer range of an SSA value.
+class IntegerValueRange {
+public:
+ /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
+ /// that is used to mark the value as unable to be analyzed further,
+ /// where `t` is the type of `value`.
+ static IntegerValueRange getPessimisticValueState(Value value);
+
+ /// Create an integer value range lattice value.
+ IntegerValueRange(ConstantIntRanges value) : value(std::move(value)) {}
+
+ /// Get the known integer value range.
+ const ConstantIntRanges &getValue() const { return value; }
+
+ /// Compare two ranges.
+ bool operator==(const IntegerValueRange &rhs) const {
+ return value == rhs.value;
+ }
+
+ /// Take the union of two ranges.
+ static IntegerValueRange join(const IntegerValueRange &lhs,
+ const IntegerValueRange &rhs) {
+ return lhs.value.rangeUnion(rhs.value);
+ }
+
+ /// Print the integer value range.
+ void print(raw_ostream &os) const { os << value; }
+
+private:
+ /// The known integer value range.
+ ConstantIntRanges value;
+};
+
+/// This lattice element represents the integer value range of an SSA value.
+/// When this lattice is updated, it automatically updates the constant value
+/// of the SSA value (if the range can be narrowed to one).
+class IntegerValueRangeLattice : public Lattice<IntegerValueRange> {
+public:
+ using Lattice::Lattice;
+
+ /// If the range can be narrowed to an integer constant, update the constant
+ /// value of the SSA value.
+ void onUpdate(DataFlowSolver *solver) const override;
+};
+
+/// Integer range analysis determines the integer value range of SSA values
+/// using operations that implement `InferIntRangeInterface` and also sets the
+/// range of iteration indices of loops with known bounds.
+class IntegerRangeAnalysis
+ : public SparseDataFlowAnalysis<IntegerValueRangeLattice> {
+public:
+ using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
+
+ /// Visit an operation. Invoke the transfer function on each operation that
+ /// implements `InferIntRangeInterface`.
+ void visitOperation(Operation *op,
+ ArrayRef<const IntegerValueRangeLattice *> operands,
+ ArrayRef<IntegerValueRangeLattice *> results) override;
+
+ /// Visit block arguments or operation results of an operation with region
+ /// control-flow for which values are not defined by region control-flow. This
+ /// function calls `InferIntRangeInterface` to provide values for block
+ /// arguments or tries to reduce the range on loop induction variables with
+ /// known bounds.
+ void
+ visitNonControlFlowArguments(Operation *op, const RegionSuccessor &successor,
+ ArrayRef<IntegerValueRangeLattice *> argLattices,
+ unsigned firstIndex) override;
+};
+
+} // end namespace dataflow
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_DATAFLOW_INTEGERRANGEANALYSIS_H
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 6456f7d6cec22..003a226b141cc 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -16,12 +16,10 @@
#define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace mlir {
-
-class RegionBranchOpInterface;
-
namespace dataflow {
//===----------------------------------------------------------------------===//
@@ -213,6 +211,14 @@ class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
+ /// Given an operation with region control-flow, the lattices of the operands,
+ /// and a region successor, compute the lattice values for block arguments
+ /// that are not accounted for by the branching control flow (ex. the bounds
+ /// of loops).
+ virtual void visitNonControlFlowArgumentsImpl(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<AbstractSparseLattice *> argLattices, unsigned firstIndex) = 0;
+
/// Get the lattice element of a value.
virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
@@ -276,6 +282,21 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;
+ /// Given an operation with possible region control-flow, the lattices of the
+ /// operands, and a region successor, compute the lattice values for block
+ /// arguments that are not accounted for by the branching control flow (ex.
+ /// the bounds of loops). By default, this method marks all such lattice
+ /// elements as having reached a pessimistic fixpoint. `firstIndex` is the
+ /// index of the first element of `argLattices` that is set by control-flow.
+ virtual void visitNonControlFlowArguments(Operation *op,
+ const RegionSuccessor &successor,
+ ArrayRef<StateT *> argLattices,
+ unsigned firstIndex) {
+ markAllPessimisticFixpoint(argLattices.take_front(firstIndex));
+ markAllPessimisticFixpoint(argLattices.drop_front(
+ firstIndex + successor.getSuccessorInputs().size()));
+ }
+
protected:
/// Get the lattice element for a value.
StateT *getLatticeElement(Value value) override {
@@ -310,6 +331,16 @@ class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
+ void visitNonControlFlowArgumentsImpl(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<AbstractSparseLattice *> argLattices,
+ unsigned firstIndex) override {
+ visitNonControlFlowArguments(
+ op, successor,
+ {reinterpret_cast<StateT *const *>(argLattices.begin()),
+ argLattices.size()},
+ firstIndex);
+ }
};
} // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 19d8fc0c3e19b..2992e05f14ddf 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -226,7 +226,6 @@ class DataFlowSolver {
/// Push a work item onto the worklist.
void enqueue(WorkItem item) { worklist.push(std::move(item)); }
-protected:
/// Get the state associated with the given program point. If it does not
/// exist, create an uninitialized state.
template <typename StateT, typename PointT>
diff --git a/mlir/include/mlir/Analysis/IntRangeAnalysis.h b/mlir/include/mlir/Analysis/IntRangeAnalysis.h
deleted file mode 100644
index b2b604359b48b..0000000000000
--- a/mlir/include/mlir/Analysis/IntRangeAnalysis.h
+++ /dev/null
@@ -1,41 +0,0 @@
-//===- IntRangeAnalysis.h - Infer Ranges Interfaces --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file declares the dataflow analysis class for integer range inference
-// so that it can be used in transformations over the `arith` dialect such as
-// branch elimination or signed->unsigned rewriting
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_ANALYSIS_INTRANGEANALYSIS_H
-#define MLIR_ANALYSIS_INTRANGEANALYSIS_H
-
-#include "mlir/Interfaces/InferIntRangeInterface.h"
-
-namespace mlir {
-namespace detail {
-class IntRangeAnalysisImpl;
-} // end namespace detail
-
-class IntRangeAnalysis {
-public:
- /// Analyze all operations rooted under (but not including)
- /// `topLevelOperation`.
- IntRangeAnalysis(Operation *topLevelOperation);
- IntRangeAnalysis(IntRangeAnalysis &&other);
- ~IntRangeAnalysis();
-
- /// Get inferred range for value `v` if one exists.
- Optional<ConstantIntRanges> getResult(Value v);
-
-private:
- std::unique_ptr<detail::IntRangeAnalysisImpl> impl;
-};
-} // end namespace mlir
-
-#endif
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
index 57f8d693b7916..abe6df1543625 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.td
@@ -30,7 +30,7 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
since the dataflow analysis handles those case), the method should call
`setValueRange` with that `Value` as an argument. When `setValueRange`
is not called for some value, it will recieve a default value of the mimimum
- and maximum values forits type (the unbounded range).
+ and maximum values for its type (the unbounded range).
When called on an op that also implements the RegionBranchOpInterface
or BranchOpInterface, this method should not attempt to infer the values
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index efac97d665e7a..26b8ea6f155c7 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -4,7 +4,6 @@ set(LLVM_OPTIONAL_SOURCES
CallGraph.cpp
DataFlowAnalysis.cpp
DataLayoutAnalysis.cpp
- IntRangeAnalysis.cpp
Liveness.cpp
SliceAnalysis.cpp
@@ -13,6 +12,7 @@ set(LLVM_OPTIONAL_SOURCES
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
DataFlow/DenseAnalysis.cpp
+ DataFlow/IntegerRangeAnalysis.cpp
DataFlow/SparseAnalysis.cpp
)
@@ -23,7 +23,6 @@ add_mlir_library(MLIRAnalysis
DataFlowAnalysis.cpp
DataFlowFramework.cpp
DataLayoutAnalysis.cpp
- IntRangeAnalysis.cpp
Liveness.cpp
SliceAnalysis.cpp
@@ -32,6 +31,7 @@ add_mlir_library(MLIRAnalysis
DataFlow/ConstantPropagationAnalysis.cpp
DataFlow/DeadCodeAnalysis.cpp
DataFlow/DenseAnalysis.cpp
+ DataFlow/IntegerRangeAnalysis.cpp
DataFlow/SparseAnalysis.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 1035c21219e7a..327ef4f90b514 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -170,10 +170,18 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
walkFn);
}
+/// Returns true if `op` is a returning terminator in region control-flow or
+/// the terminator of a callable region; ops not nested in a block never are.
+static bool isRegionOrCallableReturn(Operation *op) {
+ return op->getBlock() && !op->getNumSuccessors() &&
+ isa<RegionBranchOpInterface, CallableOpInterface>(op->getParentOp()) &&
+ op->getBlock()->getTerminator() == op;
+}
+
LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
// Initialize the analysis by visiting every op with control-flow semantics.
if (op->getNumRegions() || op->getNumSuccessors() ||
- op->hasTrait<OpTrait::IsTerminator>() || isa<CallOpInterface>(op)) {
+ isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) {
// When the liveness of the parent block changes, make sure to re-invoke the
// analysis on the op.
if (op->getBlock())
@@ -243,7 +251,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
}
}
- if (op->hasTrait<OpTrait::IsTerminator>() && !op->getNumSuccessors()) {
+ if (isRegionOrCallableReturn(op)) {
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
// Visit the exiting terminator of a region.
visitRegionTerminator(op, branch);
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
new file mode 100644
index 0000000000000..e983341faf02e
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -0,0 +1,219 @@
+//===- IntegerRangeAnalysis.cpp - Integer range analysis --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dataflow analysis class for integer range inference
+// which is used in transformations over the `arith` dialect such as
+// branch elimination or signed->unsigned rewriting
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "int-range-analysis"
+
+using namespace mlir;
+using namespace mlir::dataflow;
+
+IntegerValueRange IntegerValueRange::getPessimisticValueState(Value value) {
+ unsigned width = ConstantIntRanges::getStorageBitwidth(value.getType());
+ APInt umin = APInt::getMinValue(width);
+ APInt umax = APInt::getMaxValue(width);
+ APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
+ APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
+ return {{umin, umax, smin, smax}};
+}
+
+void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
+ Lattice::onUpdate(solver);
+
+ // If the integer range can be narrowed to a constant, update the constant
+ // value of the SSA value.
+ Optional<APInt> constant = getValue().getValue().getConstantValue();
+ auto value = point.get<Value>();
+ auto *cv = solver->getOrCreateState<Lattice<ConstantValue>>(value);
+ if (!constant)
+ return solver->propagateIfChanged(cv, cv->markPessimisticFixpoint());
+
+ Dialect *dialect;
+ if (auto *parent = value.getDefiningOp())
+ dialect = parent->getDialect();
+ else
+ dialect = value.getParentBlock()->getParentOp()->getDialect();
+ solver->propagateIfChanged(
+ cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant),
+ dialect)));
+}
+
+void IntegerRangeAnalysis::visitOperation(
+ Operation *op, ArrayRef<const IntegerValueRangeLattice *> operands,
+ ArrayRef<IntegerValueRangeLattice *> results) {
+ // Ignore non-integer outputs - return early if the op has no scalar
+ // integer results
+ bool hasIntegerResult = false;
+ for (auto it : llvm::zip(results, op->getResults())) {
+ if (std::get<1>(it).getType().isIntOrIndex()) {
+ hasIntegerResult = true;
+ } else {
+ propagateIfChanged(std::get<0>(it),
+ std::get<0>(it)->markPessimisticFixpoint());
+ }
+ }
+ if (!hasIntegerResult)
+ return;
+
+ auto inferrable = dyn_cast<InferIntRangeInterface>(op);
+ if (!inferrable)
+ return markAllPessimisticFixpoint(results);
+
+ LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
+ SmallVector<ConstantIntRanges> argRanges(
+ llvm::map_range(operands, [](const IntegerValueRangeLattice *val) {
+ return val->getValue().getValue();
+ }));
+
+ auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ auto result = v.dyn_cast<OpResult>();
+ if (!result)
+ return;
+ assert(result.getOwner() == op && "expected a result of the visited op");
+
+ LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+ IntegerValueRangeLattice *lattice = results[result.getResultNumber()];
+ Optional<IntegerValueRange> oldRange;
+ if (!lattice->isUninitialized())
+ oldRange = lattice->getValue();
+
+ ChangeResult changed = lattice->join(attrs);
+
+ // Catch loop results with loop variant bounds and conservatively make
+ // them [-inf, inf] so we don't circle around infinitely often (because
+ // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+ // and often can't).
+ bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
+ return op->hasTrait<OpTrait::IsTerminator>();
+ });
+ if (isYieldedResult && oldRange &&
+ !(lattice->getValue() == *oldRange)) {
+ LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ changed |= lattice->markPessimisticFixpoint();
+ }
+ propagateIfChanged(lattice, changed);
+ };
+
+ inferrable.inferResultRanges(argRanges, joinCallback);
+}
+
+void IntegerRangeAnalysis::visitNonControlFlowArguments(
+ Operation *op, const RegionSuccessor &successor,
+ ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) {
+ if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
+ LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n");
+ SmallVector<ConstantIntRanges> argRanges(
+ llvm::map_range(op->getOperands(), [&](Value value) {
+ return getLatticeElementFor(op, value)->getValue().getValue();
+ }));
+
+ auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
+ auto arg = v.dyn_cast<BlockArgument>();
+ if (!arg)
+ return;
+ Region *region = successor.getSuccessor();
+ if (!region || !llvm::is_contained(region->getArguments(), arg))
+ return;
+
+ LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
+ IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()];
+ Optional<IntegerValueRange> oldRange;
+ if (!lattice->isUninitialized())
+ oldRange = lattice->getValue();
+
+ ChangeResult changed = lattice->join(attrs);
+
+ // Catch loop results with loop variant bounds and conservatively make
+ // them [-inf, inf] so we don't circle around infinitely often (because
+ // the dataflow analysis in MLIR doesn't attempt to work out trip counts
+ // and often can't).
+ bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
+ return op->hasTrait<OpTrait::IsTerminator>();
+ });
+ if (isYieldedValue && oldRange && !(lattice->getValue() == *oldRange)) {
+ LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
+ changed |= lattice->markPessimisticFixpoint();
+ }
+ propagateIfChanged(lattice, changed);
+ };
+
+ inferrable.inferResultRanges(argRanges, joinCallback);
+ return;
+ }
+
+ /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
+ /// on a LoopLikeInterface return the lower/upper bound for that result if
+ /// possible.
+ auto getLoopBoundFromFold = [&](Optional<OpFoldResult> loopBound,
+ Type boundType, bool getUpper) {
+ unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
+ if (loopBound) {
+ if (loopBound->is<Attribute>()) {
+ if (auto bound =
+ loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
+ return bound.getValue();
+ } else if (auto value = loopBound->dyn_cast<Value>()) {
+ const IntegerValueRangeLattice *lattice =
+ getLatticeElementFor(op, value);
+ if (!lattice->isUninitialized())
+ return getUpper ? lattice->getValue().getValue().smax()
+ : lattice->getValue().getValue().smin();
+ }
+ }
+ // Fall back to the widest signed value in the requested direction when the
+ // bound is not a known integer constant or its lattice has not been
+ // initialized yet.
+ return getUpper ? APInt::getSignedMaxValue(width)
+ : APInt::getSignedMinValue(width);
+ };
+
+ // Infer bounds for loop arguments that have static bounds
+ if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
+ Optional<Value> iv = loop.getSingleInductionVar();
+ if (!iv || successor.isParent()) {
+ return SparseDataFlowAnalysis::visitNonControlFlowArguments(
+ op, successor, argLattices, firstIndex);
+ }
+ Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
+ Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
+ Optional<OpFoldResult> step = loop.getSingleStep();
+ APInt min = getLoopBoundFromFold(lowerBound, iv->getType(),
+ /*getUpper=*/false);
+ APInt max = getLoopBoundFromFold(upperBound, iv->getType(),
+ /*getUpper=*/true);
+ // Assume positivity for undiscoverable steps by way of getUpper = true.
+ APInt stepVal =
+ getLoopBoundFromFold(step, iv->getType(), /*getUpper=*/true);
+
+ if (stepVal.isNegative()) {
+ std::swap(min, max);
+ } else {
+ // Correct the upper bound by subtracting 1 so that it becomes a <=
+ // bound, because loops do not generally include their upper bound.
+ max -= 1;
+ }
+
+ IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
+ auto ivRange = ConstantIntRanges::fromSigned(min, max);
+ propagateIfChanged(ivEntry, ivEntry->join(ivRange));
+ return;
+ }
+
+ return SparseDataFlowAnalysis::visitNonControlFlowArguments(
+ op, successor, argLattices, firstIndex);
+}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 35487c1e1de82..776f284dfca2d 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -9,7 +9,6 @@
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Interfaces/CallInterfaces.h"
-#include "mlir/Interfaces/ControlFlowInterfaces.h"
using namespace mlir;
using namespace mlir::dataflow;
@@ -183,7 +182,9 @@ void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
}
// Otherwise, we can't reason about the data-flow.
- return markAllPessimisticFixpoint(argLattices);
+ return visitNonControlFlowArgumentsImpl(block->getParentOp(),
+ RegionSuccessor(block->getParent()),
+ argLattices, /*firstIndex=*/0);
}
// Iterate over the predecessors of the non-entry block.
@@ -236,7 +237,6 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
operands = branch.getSuccessorEntryOperands(successorIndex);
// Otherwise, try to deduce the operands from a region return-like op.
} else {
- assert(op->hasTrait<OpTrait::IsTerminator>() && "expected a terminator");
if (isRegionReturnLike(op))
operands = getRegionBranchSuccessorOperands(op, successorIndex);
}
@@ -250,17 +250,26 @@ void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
assert(inputs.size() == operands->size() &&
"expected the same number of successor inputs as operands");
- // TODO: This was updated to be exposed upstream.
unsigned firstIndex = 0;
if (inputs.size() != lattices.size()) {
- if (inputs.empty()) {
- markAllPessimisticFixpoint(lattices);
- return;
+ if (auto *op = point.dyn_cast<Operation *>()) {
+ if (!inputs.empty())
+ firstIndex = inputs.front().cast<OpResult>().getResultNumber();
+ visitNonControlFlowArgumentsImpl(
+ branch,
+ RegionSuccessor(
+ branch->getResults().slice(firstIndex, inputs.size())),
+ lattices, firstIndex);
+ } else {
+ if (!inputs.empty())
+ firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
+ Region *region = point.get<Block *>()->getParent();
+ visitNonControlFlowArgumentsImpl(
+ branch,
+ RegionSuccessor(region, region->getArguments().slice(
+ firstIndex, inputs.size())),
+ lattices, firstIndex);
}
- firstIndex = inputs.front().cast<BlockArgument>().getArgNumber();
- markAllPessimisticFixpoint(lattices.take_front(firstIndex));
- markAllPessimisticFixpoint(
- lattices.drop_front(firstIndex + inputs.size()));
}
for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index be18432468d4f..18d9ba1bd5d60 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -87,19 +87,6 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
return failure();
}
- // "Nudge" the state of the analysis by forcefully initializing states that
- // are still uninitialized. All uninitialized states in the graph can be
- // initialized in any order because the analysis reached fixpoint, meaning
- // that there are no work items that would have further nudged the analysis.
- for (AnalysisState &state :
- llvm::make_pointee_range(llvm::make_second_range(analysisStates))) {
- if (!state.isUninitialized())
- continue;
- DATAFLOW_DEBUG(llvm::dbgs() << "Default initializing " << state.debugName
- << " of " << state.point << "\n");
- propagateIfChanged(&state, state.defaultInitialize());
- }
-
// Iterate until all states are in some initialized state and the worklist
// is exhausted.
} while (!worklist.empty());
diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp
deleted file mode 100644
index f887d68d12ec2..0000000000000
--- a/mlir/lib/Analysis/IntRangeAnalysis.cpp
+++ /dev/null
@@ -1,335 +0,0 @@
-//===- IntRangeAnalysis.cpp - Infer Ranges Interfaces --*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the dataflow analysis class for integer range inference
-// which is used in transformations over the `arith` dialect such as
-// branch elimination or signed->unsigned rewriting
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Analysis/IntRangeAnalysis.h"
-#include "mlir/Analysis/DataFlowAnalysis.h"
-#include "mlir/Interfaces/InferIntRangeInterface.h"
-#include "mlir/Interfaces/LoopLikeInterface.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "int-range-analysis"
-
-using namespace mlir;
-
-namespace {
-/// A wrapper around ConstantIntRanges that provides the lattice functions
-/// expected by dataflow analysis.
-struct IntRangeLattice {
- IntRangeLattice(const ConstantIntRanges &value) : value(value){};
- IntRangeLattice(ConstantIntRanges &&value) : value(value){};
-
- bool operator==(const IntRangeLattice &other) const {
- return value == other.value;
- }
-
- /// wrapper around rangeUnion()
- static IntRangeLattice join(const IntRangeLattice &a,
- const IntRangeLattice &b) {
- return a.value.rangeUnion(b.value);
- }
-
- /// Creates a range with bitwidth 0 to represent that we don't know if the
- /// value being marked overdefined is even an integer.
- static IntRangeLattice getPessimisticValueState(MLIRContext *context) {
- APInt noIntValue = APInt::getZeroWidth();
- return ConstantIntRanges(noIntValue, noIntValue, noIntValue, noIntValue);
- }
-
- /// Create a maximal range ([0, uint_max(t)] / [int_min(t), int_max(t)])
- /// range that is used to mark the value v as unable to be analyzed further,
- /// where t is the type of v.
- static IntRangeLattice getPessimisticValueState(Value v) {
- unsigned int width = ConstantIntRanges::getStorageBitwidth(v.getType());
- APInt umin = APInt::getMinValue(width);
- APInt umax = APInt::getMaxValue(width);
- APInt smin = width != 0 ? APInt::getSignedMinValue(width) : umin;
- APInt smax = width != 0 ? APInt::getSignedMaxValue(width) : umax;
- return ConstantIntRanges{umin, umax, smin, smax};
- }
-
- ConstantIntRanges value;
-};
-} // end anonymous namespace
-
-namespace mlir {
-namespace detail {
-class IntRangeAnalysisImpl : public ForwardDataFlowAnalysis<IntRangeLattice> {
- using ForwardDataFlowAnalysis<IntRangeLattice>::ForwardDataFlowAnalysis;
-
-public:
- /// Define bounds on the results or block arguments of the operation
- /// based on the bounds on the arguments given in `operands`
- ChangeResult
- visitOperation(Operation *op,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
-
- /// Skip regions of branch ops when we can statically infer constant
- /// values for operands to the branch op and said op tells us it's safe to do
- /// so.
- LogicalResult
- getSuccessorsForOperands(BranchOpInterface branch,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands,
- SmallVectorImpl<Block *> &successors) final;
-
- /// Skip regions of branch or loop ops when we can statically infer constant
- /// values for operands to the branch op and said op tells us it's safe to do
- /// so.
- void
- getSuccessorsForOperands(RegionBranchOpInterface branch,
- Optional<unsigned> sourceIndex,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands,
- SmallVectorImpl<RegionSuccessor> &successors) final;
-
- /// Call the InferIntRangeInterface implementation for region-using ops
- /// that implement it, and infer the bounds of loop induction variables
- /// for ops that implement LoopLikeOPInterface.
- ChangeResult visitNonControlFlowArguments(
- Operation *op, const RegionSuccessor ®ion,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands) final;
-};
-} // end namespace detail
-} // end namespace mlir
-
-/// Given the results of getConstant{Lower,Upper}Bound()
-/// or getConstantStep() on a LoopLikeInterface return the lower/upper bound for
-/// that result if possible.
-static APInt getLoopBoundFromFold(Optional<OpFoldResult> loopBound,
- Type boundType,
- detail::IntRangeAnalysisImpl &analysis,
- bool getUpper) {
- unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
- if (loopBound) {
- if (loopBound->is<Attribute>()) {
- if (auto bound =
- loopBound->get<Attribute>().dyn_cast_or_null<IntegerAttr>())
- return bound.getValue();
- } else if (loopBound->is<Value>()) {
- LatticeElement<IntRangeLattice> *lattice =
- analysis.lookupLatticeElement(loopBound->get<Value>());
- if (lattice != nullptr)
- return getUpper ? lattice->getValue().value.smax()
- : lattice->getValue().value.smin();
- }
- }
- return getUpper ? APInt::getSignedMaxValue(width)
- : APInt::getSignedMinValue(width);
-}
-
-ChangeResult detail::IntRangeAnalysisImpl::visitOperation(
- Operation *op, ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
- ChangeResult result = ChangeResult::NoChange;
- // Ignore non-integer outputs - return early if the op has no scalar
- // integer results
- bool hasIntegerResult = false;
- for (Value v : op->getResults()) {
- if (v.getType().isIntOrIndex())
- hasIntegerResult = true;
- else
- result |= markAllPessimisticFixpoint(v);
- }
- if (!hasIntegerResult)
- return result;
-
- if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
- LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
- LLVM_DEBUG(inferrable->print(llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
- SmallVector<ConstantIntRanges> argRanges(
- llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
- return val->getValue().value;
- }));
-
- auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
- LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
- LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
- Optional<IntRangeLattice> oldRange;
- if (!lattice.isUninitialized())
- oldRange = lattice.getValue();
- result |= lattice.join(IntRangeLattice(attrs));
-
- // Catch loop results with loop variant bounds and conservatively make
- // them [-inf, inf] so we don't circle around infinitely often (because
- // the dataflow analysis in MLIR doesn't attempt to work out trip counts
- // and often can't).
- bool isYieldedResult = llvm::any_of(v.getUsers(), [](Operation *op) {
- return op->hasTrait<OpTrait::IsTerminator>();
- });
- if (isYieldedResult && oldRange && !(lattice.getValue() == *oldRange)) {
- LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
- result |= lattice.markPessimisticFixpoint();
- }
- };
-
- inferrable.inferResultRanges(argRanges, joinCallback);
- for (Value opResult : op->getResults()) {
- LatticeElement<IntRangeLattice> &lattice = getLatticeElement(opResult);
- // setResultRange() not called, make pessimistic.
- if (lattice.isUninitialized())
- result |= lattice.markPessimisticFixpoint();
- }
- } else if (op->getNumRegions() == 0) {
- // No regions + no result inference method -> unbounded results (ex. memory
- // ops)
- result |= markAllPessimisticFixpoint(op->getResults());
- }
- return result;
-}
-
-LogicalResult detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
- BranchOpInterface branch,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands,
- SmallVectorImpl<Block *> &successors) {
- auto toConstantAttr = [&branch](auto enumPair) -> Attribute {
- Optional<APInt> maybeConstValue =
- enumPair.value()->getValue().value.getConstantValue();
-
- if (maybeConstValue) {
- return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(),
- *maybeConstValue);
- }
- return {};
- };
- SmallVector<Attribute> inferredConsts(
- llvm::map_range(llvm::enumerate(operands), toConstantAttr));
- if (Block *singleSucc = branch.getSuccessorForOperands(inferredConsts)) {
- successors.push_back(singleSucc);
- return success();
- }
- return failure();
-}
-
-void detail::IntRangeAnalysisImpl::getSuccessorsForOperands(
- RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands,
- SmallVectorImpl<RegionSuccessor> &successors) {
- // Get a type with which to construct a constant.
- auto getOperandType = [branch, sourceIndex](unsigned index) {
- // The types of all return-like operations are the same.
- if (!sourceIndex)
- return branch->getOperand(index).getType();
-
- for (Block &block : branch->getRegion(*sourceIndex)) {
- Operation *terminator = block.getTerminator();
- if (getRegionBranchSuccessorOperands(terminator, *sourceIndex))
- return terminator->getOperand(index).getType();
- }
- return Type();
- };
-
- auto toConstantAttr = [&getOperandType](auto enumPair) -> Attribute {
- if (Optional<APInt> maybeConstValue =
- enumPair.value()->getValue().value.getConstantValue()) {
- return IntegerAttr::get(getOperandType(enumPair.index()),
- *maybeConstValue);
- }
- return {};
- };
- SmallVector<Attribute> inferredConsts(
- llvm::map_range(llvm::enumerate(operands), toConstantAttr));
- branch.getSuccessorRegions(sourceIndex, inferredConsts, successors);
-}
-
-ChangeResult detail::IntRangeAnalysisImpl::visitNonControlFlowArguments(
- Operation *op, const RegionSuccessor ®ion,
- ArrayRef<LatticeElement<IntRangeLattice> *> operands) {
- if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) {
- LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for ");
- LLVM_DEBUG(inferrable->print(llvm::dbgs()));
- LLVM_DEBUG(llvm::dbgs() << "\n");
- SmallVector<ConstantIntRanges> argRanges(
- llvm::map_range(operands, [](LatticeElement<IntRangeLattice> *val) {
- return val->getValue().value;
- }));
-
- ChangeResult result = ChangeResult::NoChange;
- auto joinCallback = [&](Value v, const ConstantIntRanges &attrs) {
- LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n");
- LatticeElement<IntRangeLattice> &lattice = getLatticeElement(v);
- Optional<IntRangeLattice> oldRange;
- if (!lattice.isUninitialized())
- oldRange = lattice.getValue();
- result |= lattice.join(IntRangeLattice(attrs));
-
- // Catch loop results with loop variant bounds and conservatively make
- // them [-inf, inf] so we don't circle around infinitely often (because
- // the dataflow analysis in MLIR doesn't attempt to work out trip counts
- // and often can't).
- bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) {
- return op->hasTrait<OpTrait::IsTerminator>();
- });
- if (isYieldedValue && oldRange && !(lattice.getValue() == *oldRange)) {
- LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n");
- result |= lattice.markPessimisticFixpoint();
- }
- };
-
- inferrable.inferResultRanges(argRanges, joinCallback);
- for (Value regionArg : region.getSuccessor()->getArguments()) {
- LatticeElement<IntRangeLattice> &lattice = getLatticeElement(regionArg);
- // setResultRange() not called, make pessimistic.
- if (lattice.isUninitialized())
- result |= lattice.markPessimisticFixpoint();
- }
-
- return result;
- }
-
- // Infer bounds for loop arguments that have static bounds
- if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
- Optional<Value> iv = loop.getSingleInductionVar();
- if (!iv) {
- return ForwardDataFlowAnalysis<
- IntRangeLattice>::visitNonControlFlowArguments(op, region, operands);
- }
- Optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
- Optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
- Optional<OpFoldResult> step = loop.getSingleStep();
- APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), *this,
- /*getUpper=*/false);
- APInt max = getLoopBoundFromFold(upperBound, iv->getType(), *this,
- /*getUpper=*/true);
- // Assume positivity for uniscoverable steps by way of getUpper = true.
- APInt stepVal =
- getLoopBoundFromFold(step, iv->getType(), *this, /*getUpper=*/true);
-
- if (stepVal.isNegative()) {
- std::swap(min, max);
- } else {
- // Correct the upper bound by subtracting 1 so that it becomes a <= bound,
- // because loops do not generally include their upper bound.
- max -= 1;
- }
-
- LatticeElement<IntRangeLattice> &ivEntry = getLatticeElement(*iv);
- return ivEntry.join(ConstantIntRanges::fromSigned(min, max));
- }
- return ForwardDataFlowAnalysis<IntRangeLattice>::visitNonControlFlowArguments(
- op, region, operands);
-}
-
-IntRangeAnalysis::IntRangeAnalysis(Operation *topLevelOperation) {
- impl = std::make_unique<mlir::detail::IntRangeAnalysisImpl>(
- topLevelOperation->getContext());
- impl->run(topLevelOperation);
-}
-
-IntRangeAnalysis::~IntRangeAnalysis() = default;
-IntRangeAnalysis::IntRangeAnalysis(IntRangeAnalysis &&other) = default;
-
-Optional<ConstantIntRanges> IntRangeAnalysis::getResult(Value v) {
- LatticeElement<IntRangeLattice> *result = impl->lookupLatticeElement(v);
- if (result == nullptr || result->isUninitialized())
- return llvm::None;
- return result->getValue().value;
-}
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
index f84990d0a8c47..82e442851b310 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp
@@ -9,33 +9,34 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
-#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace mlir::arith;
+using namespace mlir::dataflow;
/// Succeeds when a value is statically non-negative in that it has a lower
/// bound on its value (if it is treated as signed) and that bound is
/// non-negative.
-static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
- Value v) {
- Optional<ConstantIntRanges> result = analysis.getResult(v);
- if (!result.hasValue())
+static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
+ auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
+ if (!result || result->isUninitialized())
return failure();
- const ConstantIntRanges &range = result.getValue();
+ const ConstantIntRanges &range = result->getValue().getValue();
return success(range.smin().isNonNegative());
}
/// Succeeds if an op can be converted to its unsigned equivalent without
/// changing its semantics. This is the case when none of its openands or
/// results can be below 0 when analyzed from a signed perspective.
-static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
+static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
Operation *op) {
- auto nonNegativePred = [&analysis](Value v) -> bool {
- return succeeded(staticallyNonNegative(analysis, v));
+ auto nonNegativePred = [&solver](Value v) -> bool {
+ return succeeded(staticallyNonNegative(solver, v));
};
return success(llvm::all_of(op->getOperands(), nonNegativePred) &&
llvm::all_of(op->getResults(), nonNegativePred));
@@ -44,15 +45,15 @@ static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis,
/// Succeeds when the comparison predicate is a signed operation and all the
/// operands are non-negative, indicating that the cmpi operation `op` can have
/// its predicate changed to an unsigned equivalent.
-static LogicalResult isCmpIConvertable(IntRangeAnalysis &analysis, CmpIOp op) {
+static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
CmpIPredicate pred = op.getPredicate();
switch (pred) {
case CmpIPredicate::sle:
case CmpIPredicate::slt:
case CmpIPredicate::sge:
case CmpIPredicate::sgt:
- return success(llvm::all_of(op.getOperands(), [&analysis](Value v) -> bool {
- return succeeded(staticallyNonNegative(analysis, v));
+ return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
+ return succeeded(staticallyNonNegative(solver, v));
}));
default:
return failure();
@@ -109,19 +110,23 @@ struct ArithmeticUnsignedWhenEquivalentPass
void runOnOperation() override {
Operation *op = getOperation();
MLIRContext *ctx = op->getContext();
- IntRangeAnalysis analysis(op);
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<IntegerRangeAnalysis>();
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
ConversionTarget target(*ctx);
target.addLegalDialect<ArithmeticDialect>();
target
.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, CeilDivUIOp, FloorDivSIOp,
RemSIOp, MinSIOp, MaxSIOp, ExtSIOp>(
- [&analysis](Operation *op) -> Optional<bool> {
- return failed(staticallyNonNegative(analysis, op));
+ [&solver](Operation *op) -> Optional<bool> {
+ return failed(staticallyNonNegative(solver, op));
});
target.addDynamicallyLegalOp<CmpIOp>(
- [&analysis](CmpIOp op) -> Optional<bool> {
- return failed(isCmpIConvertable(analysis, op));
+ [&solver](CmpIOp op) -> Optional<bool> {
+ return failed(isCmpIConvertable(solver, op));
});
RewritePatternSet patterns(ctx);
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 902ef880364ee..a0a06352a0e80 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -38,7 +38,7 @@ static LogicalResult replaceWithConstant(DataFlowSolver &solver,
OpBuilder &builder,
OperationFolder &folder, Value value) {
auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
- if (!lattice)
+ if (!lattice || lattice->isUninitialized())
return failure();
const ConstantValue &latticeValue = lattice->getValue();
if (!latticeValue.getConstantValue())
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
index 8106c94d57368..27e994cce3b64 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp
@@ -68,9 +68,8 @@ struct ConstantAnalysis : public DataFlowAnalysis {
LogicalResult initialize(Operation *top) override {
WalkResult result = top->walk([&](Operation *op) {
- if (op->hasTrait<OpTrait::ConstantLike>())
- if (failed(visit(op)))
- return WalkResult::interrupt();
+ if (failed(visit(op)))
+ return WalkResult::interrupt();
return WalkResult::advance();
});
return success(!result.wasInterrupted());
@@ -83,13 +82,27 @@ struct ConstantAnalysis : public DataFlowAnalysis {
auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
propagateIfChanged(
constant, constant->join(ConstantValue(value, op->getDialect())));
+ return success();
}
+ markAllPessimisticFixpoint(op->getResults());
+ for (Region ®ion : op->getRegions())
+ markAllPessimisticFixpoint(region.getArguments());
return success();
}
+
+ /// Mark the constant values of all given values as having reached a
+ /// pessimistic fixpoint.
+ void markAllPessimisticFixpoint(ValueRange values) {
+ for (Value value : values) {
+ auto *constantValue = getOrCreate<Lattice<ConstantValue>>(value);
+ propagateIfChanged(constantValue,
+ constantValue->markPessimisticFixpoint());
+ }
+ }
};
-/// This is a simple pass that runs dead code analysis with no constant value
-/// provider. It marks everything as live.
+/// This is a simple pass that runs dead code analysis with a constant value
+/// provider that only understands constant operations.
struct TestDeadCodeAnalysisPass
: public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
diff --git a/mlir/test/lib/Transforms/TestIntRangeInference.cpp b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
index 1bd2a24d3ce6c..70a569d65f6d2 100644
--- a/mlir/test/lib/Transforms/TestIntRangeInference.cpp
+++ b/mlir/test/lib/Transforms/TestIntRangeInference.cpp
@@ -9,7 +9,8 @@
// functionality has been integrated into SCCP.
//===----------------------------------------------------------------------===//
-#include "mlir/Analysis/IntRangeAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
@@ -17,15 +18,17 @@
#include "mlir/Transforms/FoldUtils.h"
using namespace mlir;
+using namespace mlir::dataflow;
/// Patterned after SCCP
-static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
- OpBuilder &b, OperationFolder &folder,
- Value value) {
- Optional<ConstantIntRanges> maybeInferredRange = analysis.getResult(value);
- if (!maybeInferredRange)
+static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
+ OperationFolder &folder, Value value) {
+ auto *maybeInferredRange =
+ solver.lookupState<IntegerValueRangeLattice>(value);
+ if (!maybeInferredRange || maybeInferredRange->isUninitialized())
return failure();
- const ConstantIntRanges &inferredRange = maybeInferredRange.getValue();
+ const ConstantIntRanges &inferredRange =
+ maybeInferredRange->getValue().getValue();
Optional<APInt> maybeConstValue = inferredRange.getConstantValue();
if (!maybeConstValue.hasValue())
return failure();
@@ -44,7 +47,7 @@ static LogicalResult replaceWithConstant(IntRangeAnalysis &analysis,
return success();
}
-static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
+static void rewrite(DataFlowSolver &solver, MLIRContext *context,
MutableArrayRef<Region> initialRegions) {
SmallVector<Block *> worklist;
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
@@ -67,7 +70,7 @@ static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
bool replacedAll = op.getNumResults() != 0;
for (Value res : op.getResults())
replacedAll &=
- succeeded(replaceWithConstant(analysis, builder, folder, res));
+ succeeded(replaceWithConstant(solver, builder, folder, res));
// If all of the results of the operation were replaced, try to erase
// the operation completely.
@@ -84,7 +87,7 @@ static void rewrite(IntRangeAnalysis &analysis, MLIRContext *context,
// Replace any block arguments with constants.
builder.setInsertionPointToStart(block);
for (BlockArgument arg : block->getArguments())
- (void)replaceWithConstant(analysis, builder, folder, arg);
+ (void)replaceWithConstant(solver, builder, folder, arg);
}
}
@@ -100,8 +103,12 @@ struct TestIntRangeInference
void runOnOperation() override {
Operation *op = getOperation();
- IntRangeAnalysis analysis(op);
- rewrite(analysis, op->getContext(), op->getRegions());
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<IntegerRangeAnalysis>();
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
+ rewrite(solver, op->getContext(), op->getRegions());
}
};
} // end anonymous namespace
More information about the Mlir-commits
mailing list