[Mlir-commits] [mlir] 9432fbf - [mlir] An implementation of sparse data-flow analysis
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jul 7 10:17:09 PDT 2022
Author: Mogball
Date: 2022-07-07T10:17:04-07:00
New Revision: 9432fbfe1327da454c34bd1b6ed448fd58d56e22
URL: https://github.com/llvm/llvm-project/commit/9432fbfe1327da454c34bd1b6ed448fd58d56e22
DIFF: https://github.com/llvm/llvm-project/commit/9432fbfe1327da454c34bd1b6ed448fd58d56e22.diff
LOG: [mlir] An implementation of sparse data-flow analysis
This patch introduces a (forward) sparse data-flow analysis implemented with the data-flow analysis framework. The analysis can consume liveness information provided by dead-code analysis, which makes it conditional. This patch re-implements SCCP using dead-code analysis and (conditional) constant propagation analyses.
Depends on D127064
Reviewed By: rriddle, phisiart
Differential Revision: https://reviews.llvm.org/D127139
Added:
Modified:
mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Transforms/SCCP.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
index 5e04516e310b1..b07994755c03f 100644
--- a/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h
@@ -60,6 +60,24 @@ class ConstantValue {
Dialect *dialect;
};
+//===----------------------------------------------------------------------===//
+// SparseConstantPropagation
+//===----------------------------------------------------------------------===//
+
+/// This analysis implements sparse constant propagation, which attempts to
+/// determine constant-valued results for operations using constant-valued
+/// operands, by speculatively folding operations. When combined with dead-code
+/// analysis, this becomes sparse conditional constant propagation (SCCP).
+///
+/// Each SSA value is tracked with a `Lattice<ConstantValue>`.
+class SparseConstantPropagation
+ : public SparseDataFlowAnalysis<Lattice<ConstantValue>> {
+public:
+ /// Reuse the base-class constructor, which takes the `DataFlowSolver`.
+ using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
+
+ /// Transfer function: speculatively fold `op` using the constant values
+ /// held in `operands`, and join the folded constants (or the lattices of
+ /// the values the op folds to) into `results`.
+ void visitOperation(Operation *op,
+ ArrayRef<const Lattice<ConstantValue> *> operands,
+ ArrayRef<Lattice<ConstantValue> *> results) override;
+};
+
} // end namespace dataflow
} // end namespace mlir
diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
index 0e75bc50def1a..4c8689c271db5 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -89,6 +89,9 @@ class Executable : public AnalysisState {
/// the predecessor to its entry block, and the exiting terminator or a callable
/// operation can be the predecessor of the call operation.
///
+/// The state can optionally contain information about which values are
+/// propagated from each predecessor to the successor point.
+///
/// The state can indicate that it is underdefined, meaning that not all live
/// control-flow predecessors can be known.
class PredecessorState : public AnalysisState {
@@ -118,12 +121,17 @@ class PredecessorState : public AnalysisState {
return knownPredecessors.getArrayRef();
}
- /// Add a known predecessor.
- ChangeResult join(Operation *predecessor) {
- return knownPredecessors.insert(predecessor) ? ChangeResult::Change
- : ChangeResult::NoChange;
+ /// Get the successor inputs recorded for `predecessor`. Returns an empty
+ /// range when no inputs were recorded for it.
+ ValueRange getSuccessorInputs(Operation *predecessor) const {
+ return successorInputs.lookup(predecessor);
}
+ /// Add a known predecessor. Returns `ChangeResult::Change` only if the
+ /// predecessor was not already known.
+ ChangeResult join(Operation *predecessor);
+
+ /// Add a known predecessor with successor inputs. Returns
+ /// `ChangeResult::Change` if the predecessor is new or if non-empty
+ /// `inputs` differ from the inputs previously recorded for it. An empty
+ /// `inputs` range records only the predecessor.
+ ChangeResult join(Operation *predecessor, ValueRange inputs);
+
private:
/// Whether all predecessors are known. Optimistically assume that we know
/// all predecessors.
@@ -133,6 +141,9 @@ class PredecessorState : public AnalysisState {
SetVector<Operation *, SmallVector<Operation *, 4>,
SmallPtrSet<Operation *, 4>>
knownPredecessors;
+
+ /// The successor inputs when branching from a given predecessor. Only
+ /// predecessors joined with a non-empty input range get an entry here.
+ DenseMap<Operation *, ValueRange> successorInputs;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index f15416a9cfad2..5907da0ef8da1 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -16,6 +16,7 @@
#define MLIR_ANALYSIS_DATAFLOW_SPARSEANALYSIS_H
#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
namespace mlir {
@@ -179,6 +180,137 @@ class Lattice : public AbstractSparseLattice {
Optional<ValueT> optimisticValue;
};
+//===----------------------------------------------------------------------===//
+// AbstractSparseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Base class for sparse (forward) data-flow analyses. A sparse analysis
+/// implements a transfer function on operations from the lattices of the
+/// operands to the lattices of the results. This analysis will propagate
+/// lattices across control-flow edges and the callgraph using liveness
+/// information.
+class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
+public:
+ /// Initialize the analysis by visiting every owner of an SSA value: all
+ /// operations and blocks. The entry-block arguments of `top`'s regions are
+ /// first marked as having reached their pessimistic fixpoints.
+ LogicalResult initialize(Operation *top) override;
+
+ /// Visit a program point. If this is a block and all control-flow
+ /// predecessors or callsites are known, then the argument lattices are
+ /// propagated from them. If this is a call operation or an operation with
+ /// region control-flow, then its result lattices are set accordingly.
+ /// Otherwise, the operation transfer function is invoked. Returns failure
+ /// for program points that are neither operations nor blocks.
+ LogicalResult visit(ProgramPoint point) override;
+
+protected:
+ /// Construct the analysis and register the `CFGEdge` program point kind,
+ /// whose liveness is queried when visiting non-entry blocks.
+ explicit AbstractSparseDataFlowAnalysis(DataFlowSolver &solver);
+
+ /// The operation transfer function. Given the operand lattices, this
+ /// function is expected to set the result lattices.
+ virtual void
+ visitOperationImpl(Operation *op,
+ ArrayRef<const AbstractSparseLattice *> operandLattices,
+ ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
+
+ /// Get the lattice element of a value.
+ virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
+
+ /// Get a read-only lattice element for a value and add it as a dependency to
+ /// a program point.
+ const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
+ Value value);
+
+ /// Mark the given lattice elements as having reached their pessimistic
+ /// fixpoints and propagate an update if any changed.
+ void markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice *> lattices);
+
+ /// Join `rhs` into the lattice element `lhs` and propagate an update if it
+ /// changed.
+ void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+
+private:
+ /// Recursively initialize the analysis on nested operations and blocks.
+ LogicalResult initializeRecursively(Operation *op);
+
+ /// Visit an operation. If this is a call operation or an operation with
+ /// region control-flow, then its result lattices are set accordingly.
+ /// Otherwise, the operation transfer function is invoked.
+ void visitOperation(Operation *op);
+
+ /// Visit a block to compute the lattice values of its arguments. If this is
+ /// an entry block, then the argument values are determined from the block's
+ /// "predecessors" as set by `PredecessorState`. The predecessors can be
+ /// region terminators, the parent region-branch operation, or callsites.
+ /// Otherwise, the values are determined from the block predecessors whose
+ /// incoming CFG edges are live.
+ void visitBlock(Block *block);
+
+ /// Visit a program point `point` with predecessors within a region branch
+ /// operation `branch`, which can either be the entry block of one of the
+ /// regions or the parent operation itself, and set either the argument or
+ /// parent result lattices.
+ void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
+ Optional<unsigned> successorIndex,
+ ArrayRef<AbstractSparseLattice *> lattices);
+};
+
+//===----------------------------------------------------------------------===//
+// SparseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// A sparse (forward) data-flow analysis over a concrete lattice type.
+/// Subclasses supply the per-operation transfer function through
+/// `visitOperation`. Propagation across control flow and the callgraph is
+/// inherited from `AbstractSparseDataFlowAnalysis`.
+///
+/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
+template <typename StateT>
+class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
+  static_assert(
+      std::is_base_of<AbstractSparseLattice, StateT>::value,
+      "analysis state class expected to subclass AbstractSparseLattice");
+
+public:
+  explicit SparseDataFlowAnalysis(DataFlowSolver &solver)
+      : AbstractSparseDataFlowAnalysis(solver) {}
+
+  /// Transfer function for `op`: read the lattices of its operands and set
+  /// the lattices of its results.
+  virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
+                              ArrayRef<StateT *> results) = 0;
+
+protected:
+  /// Fetch the typed lattice element of `value`, creating it if needed.
+  StateT *getLatticeElement(Value value) override {
+    return getOrCreate<StateT>(value);
+  }
+
+  /// Fetch the typed lattice element of `value` for reading, and register
+  /// `point` as depending on it.
+  const StateT *getLatticeElementFor(ProgramPoint point, Value value) {
+    const AbstractSparseLattice *element =
+        AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value);
+    return static_cast<const StateT *>(element);
+  }
+
+  /// Typed overload: move every element of `lattices` to its pessimistic
+  /// fixpoint.
+  void markAllPessimisticFixpoint(ArrayRef<StateT *> lattices) {
+    AbstractSparseLattice *const *erased =
+        reinterpret_cast<AbstractSparseLattice *const *>(lattices.data());
+    AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
+        ArrayRef<AbstractSparseLattice *>(erased, lattices.size()));
+  }
+
+private:
+  /// Bridge the type-erased hook used by the base class to the typed
+  /// `visitOperation` hook. The pointer arrays are reinterpreted in place
+  /// rather than copied.
+  void visitOperationImpl(
+      Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
+      ArrayRef<AbstractSparseLattice *> resultLattices) override {
+    ArrayRef<const StateT *> typedOperands(
+        reinterpret_cast<const StateT *const *>(operandLattices.data()),
+        operandLattices.size());
+    ArrayRef<StateT *> typedResults(
+        reinterpret_cast<StateT *const *>(resultLattices.data()),
+        resultLattices.size());
+    visitOperation(op, typedOperands, typedResults);
+  }
+};
+
} // end namespace dataflow
} // end namespace mlir
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index f97f8ccdb8649..386237e47b5e7 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -7,6 +7,10 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/IR/OpDefinition.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "constant-propagation"
using namespace mlir;
using namespace mlir::dataflow;
@@ -20,3 +24,68 @@ void ConstantValue::print(raw_ostream &os) const {
return constant.print(os);
os << "<NO VALUE>";
}
+
+//===----------------------------------------------------------------------===//
+// SparseConstantPropagation
+//===----------------------------------------------------------------------===//
+
+/// Transfer function for (conditional) constant propagation. It
+/// speculatively folds `op` using the constant values currently held in its
+/// operand lattices and joins the outcome into the result lattices.
+///
+/// Whenever the results cannot be simulated (region operations, failed
+/// folds, in-place folds), they are moved to their pessimistic fixpoint.
+/// This means no result lattice is left uninitialized after a visit.
+void SparseConstantPropagation::visitOperation(
+    Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
+    ArrayRef<Lattice<ConstantValue> *> results) {
+  LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
+
+  // Don't try to simulate the results of a region operation as we can't
+  // guarantee that folding will be out-of-place. We don't allow in-place
+  // folds as the desire here is for simulated execution, and not general
+  // folding. Region-branch and call operations are resolved by the sparse
+  // analysis driver before reaching this hook, so the results of any other
+  // region operation are unknown. Mark them overdefined rather than leaving
+  // them uninitialized: the driver skips users with uninitialized operands.
+  if (op->getNumRegions()) {
+    markAllPessimisticFixpoint(results);
+    return;
+  }
+
+  SmallVector<Attribute, 8> constantOperands;
+  constantOperands.reserve(op->getNumOperands());
+  for (auto *operandLattice : operands)
+    constantOperands.push_back(operandLattice->getValue().getConstantValue());
+
+  // Save the original operands and attributes just in case the operation
+  // folds in-place. The constant passed in may not correspond to the real
+  // runtime value, so in-place updates are not allowed.
+  SmallVector<Value, 8> originalOperands(op->getOperands());
+  DictionaryAttr originalAttrs = op->getAttrDictionary();
+
+  // Simulate the result of folding this operation to a constant. If folding
+  // fails or was an in-place fold, mark the results as overdefined.
+  SmallVector<OpFoldResult, 8> foldResults;
+  foldResults.reserve(op->getNumResults());
+  if (failed(op->fold(constantOperands, foldResults))) {
+    markAllPessimisticFixpoint(results);
+    return;
+  }
+
+  // If the folding was in-place, mark the results as overdefined and reset
+  // the operation. We don't allow in-place folds as the desire here is for
+  // simulated execution, and not general folding.
+  if (foldResults.empty()) {
+    op->setOperands(originalOperands);
+    op->setAttrs(originalAttrs);
+    markAllPessimisticFixpoint(results);
+    return;
+  }
+
+  // Merge the fold results into the lattice for this operation.
+  assert(foldResults.size() == op->getNumResults() && "invalid result size");
+  for (const auto it : llvm::zip(results, foldResults)) {
+    Lattice<ConstantValue> *lattice = std::get<0>(it);
+
+    // Merge in the result of the fold, either a constant or a value.
+    OpFoldResult foldResult = std::get<1>(it);
+    if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
+      LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
+      propagateIfChanged(lattice,
+                         lattice->join(ConstantValue(attr, op->getDialect())));
+    } else {
+      Value foldedValue = foldResult.get<Value>();
+      LLVM_DEBUG(llvm::dbgs() << "Folded to value: " << foldedValue << "\n");
+      // Make `op` depend on the folded-to value's lattice, so the fold is
+      // re-simulated when that lattice changes. The value need not be one of
+      // `op`'s operands, which are the only lattices `op` is otherwise
+      // subscribed to.
+      AbstractSparseDataFlowAnalysis::join(
+          lattice, *getLatticeElementFor(op, foldedValue));
+    }
+  }
+}
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 49319c2a33b7f..1035c21219e7a 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -59,6 +59,23 @@ void PredecessorState::print(raw_ostream &os) const {
os << " " << *op << "\n";
}
+/// Record `predecessor` as a known predecessor. Returns `Change` only the
+/// first time a given predecessor is added.
+ChangeResult PredecessorState::join(Operation *predecessor) {
+  bool inserted = knownPredecessors.insert(predecessor);
+  if (!inserted)
+    return ChangeResult::NoChange;
+  return ChangeResult::Change;
+}
+
+/// Record `predecessor` together with the values it forwards to the successor.
+/// An empty `inputs` range records only the predecessor itself. Otherwise the
+/// stored inputs are overwritten whenever they differ from `inputs`.
+ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) {
+  ChangeResult changed = join(predecessor);
+  if (inputs.empty())
+    return changed;
+
+  ValueRange &recorded = successorInputs[predecessor];
+  if (recorded != inputs) {
+    recorded = inputs;
+    changed |= ChangeResult::Change;
+  }
+  return changed;
+}
+
//===----------------------------------------------------------------------===//
// CFGEdge
//===----------------------------------------------------------------------===//
@@ -333,14 +350,18 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
SmallVector<RegionSuccessor> successors;
branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
for (const RegionSuccessor &successor : successors) {
+ // The successor can be either an entry block or the parent operation.
+ ProgramPoint point = successor.getSuccessor()
+ ? &successor.getSuccessor()->front()
+ : ProgramPoint(branch);
// Mark the entry block as executable.
- Region *region = successor.getSuccessor();
- assert(region && "expected a region successor");
- auto *state = getOrCreate<Executable>(®ion->front());
+ auto *state = getOrCreate<Executable>(point);
propagateIfChanged(state, state->setToLive());
// Add the parent op as a predecessor.
- auto *predecessors = getOrCreate<PredecessorState>(®ion->front());
- propagateIfChanged(predecessors, predecessors->join(branch));
+ auto *predecessors = getOrCreate<PredecessorState>(point);
+ propagateIfChanged(
+ predecessors,
+ predecessors->join(branch, successor.getSuccessorInputs()));
}
}
@@ -366,7 +387,8 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
// Add this terminator as a predecessor to the parent op.
predecessors = getOrCreate<PredecessorState>(branch);
}
- propagateIfChanged(predecessors, predecessors->join(op));
+ propagateIfChanged(predecessors,
+ predecessors->join(op, successor.getSuccessorInputs()));
}
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index e2640280fffd2..35487c1e1de82 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -7,6 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
using namespace mlir;
using namespace mlir::dataflow;
@@ -21,3 +24,265 @@ void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
for (DataFlowAnalysis *analysis : useDefSubscribers)
solver->enqueue({user, analysis});
}
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Construct the analysis. The CFG-edge program point kind is registered
+/// because visiting a non-entry block queries the liveness of each incoming
+/// CFG edge.
+AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis(
+    DataFlowSolver &solver)
+    : DataFlowAnalysis(solver) {
+  registerPointKind<CFGEdge>();
+}
+
+/// Seed the analysis. The entry-block arguments of `top`'s regions are
+/// conservatively marked as having reached their pessimistic fixpoints.
+/// Then every operation and block nested under `top` is visited once.
+LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
+  for (Region &region : top->getRegions()) {
+    if (region.empty())
+      continue;
+    Block &entry = region.front();
+    for (BlockArgument arg : entry.getArguments())
+      markAllPessimisticFixpoint(getLatticeElement(arg));
+  }
+  return initializeRecursively(top);
+}
+
+/// Visit `op`, then depth-first every block and operation nested in it, so
+/// that every owner of an SSA value is visited once. Each block's
+/// `Executable` state is subscribed to via `blockContentSubscribe`
+/// (presumably so the block's contents are revisited once it becomes live).
+LogicalResult
+AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
+  visitOperation(op);
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      getOrCreate<Executable>(&block)->blockContentSubscribe(this);
+      visitBlock(&block);
+      for (Operation &nested : block) {
+        if (failed(initializeRecursively(&nested)))
+          return failure();
+      }
+    }
+  }
+  return success();
+}
+
+/// Dispatch a program point to the operation or block visitor. Any other
+/// kind of program point is rejected with failure.
+LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
+  if (auto *op = point.dyn_cast<Operation *>()) {
+    visitOperation(op);
+    return success();
+  }
+  if (auto *block = point.dyn_cast<Block *>()) {
+    visitBlock(block);
+    return success();
+  }
+  return failure();
+}
+
+/// Compute the result lattices of `op`. The visit is skipped when any of the
+/// following holds:
+///   - `op` has no results;
+///   - its block is not live;
+///   - every result lattice is already at its fixpoint.
+/// Otherwise the results are computed as follows:
+///   - region-branch operations: from region control flow;
+///   - call operations: from the operands of the known return sites;
+///   - all other operations: through the transfer function, once every
+///     operand lattice is initialized.
+void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
+ // Exit early on operations with no results.
+ if (op->getNumResults() == 0)
+ return;
+
+ // If the containing block is not executable, bail out.
+ if (!getOrCreate<Executable>(op->getBlock())->isLive())
+ return;
+
+ // Get the result lattices.
+ SmallVector<AbstractSparseLattice *> resultLattices;
+ resultLattices.reserve(op->getNumResults());
+ // Track whether all results have reached their fixpoint.
+ bool allAtFixpoint = true;
+ for (Value result : op->getResults()) {
+ AbstractSparseLattice *resultLattice = getLatticeElement(result);
+ allAtFixpoint &= resultLattice->isAtFixpoint();
+ resultLattices.push_back(resultLattice);
+ }
+ // If all result lattices have reached a fixpoint, there is nothing to do.
+ if (allAtFixpoint)
+ return;
+
+ // The results of a region branch operation are determined by control-flow.
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ return visitRegionSuccessors({branch}, branch,
+ /*successorIndex=*/llvm::None, resultLattices);
+ }
+
+ // The results of a call operation are determined by the callgraph.
+ if (auto call = dyn_cast<CallOpInterface>(op)) {
+ const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
+ // If not all return sites are known, then conservatively assume we can't
+ // reason about the data-flow.
+ if (!predecessors->allPredecessorsKnown())
+ return markAllPessimisticFixpoint(resultLattices);
+ // Each known predecessor's operands are joined positionally into the
+ // call's result lattices.
+ for (Operation *predecessor : predecessors->getKnownPredecessors())
+ for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
+ join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
+ return;
+ }
+
+ // Grab the lattice elements of the operands.
+ SmallVector<const AbstractSparseLattice *> operandLattices;
+ operandLattices.reserve(op->getNumOperands());
+ for (Value operand : op->getOperands()) {
+ AbstractSparseLattice *operandLattice = getLatticeElement(operand);
+ // Subscribe before the uninitialized check below, so that `op` is
+ // re-enqueued once this operand's lattice is updated.
+ operandLattice->useDefSubscribe(this);
+ // If any of the operand states are not initialized, bail out.
+ if (operandLattice->isUninitialized())
+ return;
+ operandLattices.push_back(operandLattice);
+ }
+
+ // Invoke the operation transfer function.
+ visitOperationImpl(op, operandLattices, resultLattices);
+}
+
+/// Compute the argument lattices of `block`. Blocks without arguments,
+/// non-live blocks, and blocks whose argument lattices are all at their
+/// fixpoints are skipped.
+///
+/// Entry blocks get their arguments as follows:
+///   - callable entry blocks: from the known call sites;
+///   - region-branch entry blocks: from region control flow;
+///   - any other entry block: marked pessimistic.
+///
+/// Non-entry blocks join the successor operands of each predecessor whose
+/// incoming CFG edge is live.
+void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
+ // Exit early on blocks with no arguments.
+ if (block->getNumArguments() == 0)
+ return;
+
+ // If the block is not executable, bail out.
+ if (!getOrCreate<Executable>(block)->isLive())
+ return;
+
+ // Get the argument lattices.
+ SmallVector<AbstractSparseLattice *> argLattices;
+ argLattices.reserve(block->getNumArguments());
+ bool allAtFixpoint = true;
+ for (BlockArgument argument : block->getArguments()) {
+ AbstractSparseLattice *argLattice = getLatticeElement(argument);
+ allAtFixpoint &= argLattice->isAtFixpoint();
+ argLattices.push_back(argLattice);
+ }
+ // If all argument lattices have reached their fixpoints, then there is
+ // nothing to do.
+ if (allAtFixpoint)
+ return;
+
+ // The argument lattices of entry blocks are set by region control-flow or the
+ // callgraph.
+ if (block->isEntryBlock()) {
+ // Check if this block is the entry block of a callable region.
+ auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
+ if (callable && callable.getCallableRegion() == block->getParent()) {
+ const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
+ // If not all callsites are known, conservatively mark all lattices as
+ // having reached their pessimistic fixpoints.
+ if (!callsites->allPredecessorsKnown())
+ return markAllPessimisticFixpoint(argLattices);
+ // Join each call site's argument operands positionally into the entry
+ // block arguments.
+ for (Operation *callsite : callsites->getKnownPredecessors()) {
+ auto call = cast<CallOpInterface>(callsite);
+ for (auto it : llvm::zip(call.getArgOperands(), argLattices))
+ join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
+ }
+ return;
+ }
+
+ // Check if the lattices can be determined from region control flow.
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
+ return visitRegionSuccessors(
+ block, branch, block->getParent()->getRegionNumber(), argLattices);
+ }
+
+ // Otherwise, we can't reason about the data-flow.
+ return markAllPessimisticFixpoint(argLattices);
+ }
+
+ // Iterate over the predecessors of the non-entry block.
+ for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
+ it != e; ++it) {
+ Block *predecessor = *it;
+
+ // If the edge from the predecessor block to the current block is not live,
+ // skip this predecessor. Subscribing first means the block is presumably
+ // revisited if the edge becomes live later.
+ auto *edgeExecutable =
+ getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
+ edgeExecutable->blockContentSubscribe(this);
+ if (!edgeExecutable->isLive())
+ continue;
+
+ // Check if we can reason about the data-flow from the predecessor.
+ if (auto branch =
+ dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
+ SuccessorOperands operands =
+ branch.getSuccessorOperands(it.getSuccessorIndex());
+ // NOTE: the inner `it` (enumerate over argument lattices) shadows the
+ // outer predecessor iterator `it`.
+ for (auto &it : llvm::enumerate(argLattices)) {
+ if (Value operand = operands[it.index()]) {
+ join(it.value(), *getLatticeElementFor(block, operand));
+ } else {
+ // Conservatively mark internally produced arguments as having reached
+ // their pessimistic fixpoint.
+ markAllPessimisticFixpoint(it.value());
+ }
+ }
+ } else {
+ return markAllPessimisticFixpoint(argLattices);
+ }
+ }
+}
+
+/// Join into `lattices` the values flowing into `point` from each of its known
+/// predecessors within the region branch operation `branch`. `point` is one
+/// of two things:
+///   - the entry block of region `successorIndex`, in which case `lattices`
+///     are that block's arguments;
+///   - `branch` itself, with `successorIndex` unset, in which case `lattices`
+///     are its results.
+/// If the forwarded operands of any predecessor cannot be determined, every
+/// lattice is marked pessimistic.
+void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
+    ProgramPoint point, RegionBranchOpInterface branch,
+    Optional<unsigned> successorIndex,
+    ArrayRef<AbstractSparseLattice *> lattices) {
+  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
+  assert(predecessors->allPredecessorsKnown() &&
+         "unexpected unresolved region successors");
+
+  for (Operation *op : predecessors->getKnownPredecessors()) {
+    // Get the incoming successor operands.
+    Optional<OperandRange> operands;
+
+    // Check if the predecessor is the parent op.
+    if (op == branch) {
+      operands = branch.getSuccessorEntryOperands(successorIndex);
+      // Otherwise, try to deduce the operands from a region return-like op.
+    } else {
+      assert(op->hasTrait<OpTrait::IsTerminator>() && "expected a terminator");
+      if (isRegionReturnLike(op))
+        operands = getRegionBranchSuccessorOperands(op, successorIndex);
+    }
+
+    if (!operands) {
+      // We can't reason about the data-flow.
+      return markAllPessimisticFixpoint(lattices);
+    }
+
+    ValueRange inputs = predecessors->getSuccessorInputs(op);
+    assert(inputs.size() == operands->size() &&
+           "expected the same number of successor inputs as operands");
+
+    // The successor inputs may cover only part of `lattices`. This code
+    // assumes they form a contiguous sub-range starting at the position of
+    // the first input. Lattices outside that range are not fed by this
+    // predecessor and are conservatively marked pessimistic.
+    unsigned firstIndex = 0;
+    if (inputs.size() != lattices.size()) {
+      if (inputs.empty())
+        return markAllPessimisticFixpoint(lattices);
+      // The inputs are block arguments when `point` is an entry block, and
+      // operation results when `point` is `branch` itself. Casting
+      // unconditionally to `BlockArgument` would assert on the latter.
+      Value firstInput = inputs.front();
+      if (auto arg = firstInput.dyn_cast<BlockArgument>())
+        firstIndex = arg.getArgNumber();
+      else
+        firstIndex = firstInput.cast<OpResult>().getResultNumber();
+      markAllPessimisticFixpoint(lattices.take_front(firstIndex));
+      markAllPessimisticFixpoint(
+          lattices.drop_front(firstIndex + inputs.size()));
+    }
+
+    for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
+      join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
+  }
+}
+
+/// Look up the lattice element of `value` and record `point` as depending on
+/// it via `addDependency`.
+const AbstractSparseLattice *
+AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
+                                                     Value value) {
+  AbstractSparseLattice *element = getLatticeElement(value);
+  addDependency(element, point);
+  return element;
+}
+
+/// Move each lattice in `lattices` to its pessimistic fixpoint and pass the
+/// resulting change status to `propagateIfChanged`.
+void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
+    ArrayRef<AbstractSparseLattice *> lattices) {
+  for (size_t i = 0, e = lattices.size(); i != e; ++i) {
+    AbstractSparseLattice *lattice = lattices[i];
+    propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
+  }
+}
+
+/// Join `rhs` into `lhs` and pass the resulting change status to
+/// `propagateIfChanged`.
+void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
+                                          const AbstractSparseLattice &rhs) {
+  ChangeResult result = lhs->join(rhs);
+  propagateIfChanged(lhs, result);
+}
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 11d55e7454a0a..902ef880364ee 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -15,150 +15,17 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
-#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
-#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/Passes.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "sccp"
using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// SCCP Analysis
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct SCCPLatticeValue {
- SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr)
- : constant(constant), constantDialect(dialect) {}
-
- /// The pessimistic state of SCCP is non-constant.
- static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) {
- return SCCPLatticeValue();
- }
- static SCCPLatticeValue getPessimisticValueState(Value value) {
- return SCCPLatticeValue();
- }
-
- /// Equivalence for SCCP only accounts for the constant, not the originating
- /// dialect.
- bool operator==(const SCCPLatticeValue &rhs) const {
- return constant == rhs.constant;
- }
-
- /// To join the state of two values, we simply check for equivalence.
- static SCCPLatticeValue join(const SCCPLatticeValue &lhs,
- const SCCPLatticeValue &rhs) {
- return lhs == rhs ? lhs : SCCPLatticeValue();
- }
-
- /// The constant attribute value.
- Attribute constant;
-
- /// The dialect the constant originated from. This is not used as part of the
- /// key, and is only needed to materialize the held constant if necessary.
- Dialect *constantDialect;
-};
-
-struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
- using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis;
- ~SCCPAnalysis() override = default;
-
- ChangeResult
- visitOperation(Operation *op,
- ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final {
-
- LLVM_DEBUG(llvm::dbgs() << "SCCP: Visiting operation: " << *op << "\n");
-
- // Don't try to simulate the results of a region operation as we can't
- // guarantee that folding will be out-of-place. We don't allow in-place
- // folds as the desire here is for simulated execution, and not general
- // folding.
- if (op->getNumRegions())
- return markAllPessimisticFixpoint(op->getResults());
-
- SmallVector<Attribute> constantOperands(
- llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
- return value->getValue().constant;
- }));
-
- // Save the original operands and attributes just in case the operation
- // folds in-place. The constant passed in may not correspond to the real
- // runtime value, so in-place updates are not allowed.
- SmallVector<Value, 8> originalOperands(op->getOperands());
- DictionaryAttr originalAttrs = op->getAttrDictionary();
-
- // Simulate the result of folding this operation to a constant. If folding
- // fails or was an in-place fold, mark the results as overdefined.
- SmallVector<OpFoldResult, 8> foldResults;
- foldResults.reserve(op->getNumResults());
- if (failed(op->fold(constantOperands, foldResults)))
- return markAllPessimisticFixpoint(op->getResults());
-
- // If the folding was in-place, mark the results as overdefined and reset
- // the operation. We don't allow in-place folds as the desire here is for
- // simulated execution, and not general folding.
- if (foldResults.empty()) {
- op->setOperands(originalOperands);
- op->setAttrs(originalAttrs);
- return markAllPessimisticFixpoint(op->getResults());
- }
-
- // Merge the fold results into the lattice for this operation.
- assert(foldResults.size() == op->getNumResults() && "invalid result size");
- Dialect *dialect = op->getDialect();
- ChangeResult result = ChangeResult::NoChange;
- for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
- LatticeElement<SCCPLatticeValue> &lattice =
- getLatticeElement(op->getResult(i));
-
- // Merge in the result of the fold, either a constant or a value.
- OpFoldResult foldResult = foldResults[i];
- if (Attribute attr = foldResult.dyn_cast<Attribute>())
- result |= lattice.join(SCCPLatticeValue(attr, dialect));
- else
- result |= lattice.join(getLatticeElement(foldResult.get<Value>()));
- }
- return result;
- }
-
- /// Implementation of `getSuccessorsForOperands` that uses constant operands
- /// to potentially remove dead successors.
- LogicalResult getSuccessorsForOperands(
- BranchOpInterface branch,
- ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
- SmallVectorImpl<Block *> &successors) final {
- SmallVector<Attribute> constantOperands(
- llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
- return value->getValue().constant;
- }));
- if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
- successors.push_back(singleSucc);
- return success();
- }
- return failure();
- }
-
- /// Implementation of `getSuccessorsForOperands` that uses constant operands
- /// to potentially remove dead region successors.
- void getSuccessorsForOperands(
- RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
- ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
- SmallVectorImpl<RegionSuccessor> &successors) final {
- SmallVector<Attribute> constantOperands(
- llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
- return value->getValue().constant;
- }));
- branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
- }
-};
-} // namespace
+using namespace mlir::dataflow;
//===----------------------------------------------------------------------===//
// SCCP Rewrites
@@ -167,21 +34,21 @@ struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
/// Replace the given value with a constant if the corresponding lattice
/// represents a constant. Returns success if the value was replaced, failure
/// otherwise.
-static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
+static LogicalResult replaceWithConstant(DataFlowSolver &solver,
OpBuilder &builder,
OperationFolder &folder, Value value) {
- LatticeElement<SCCPLatticeValue> *lattice =
- analysis.lookupLatticeElement(value);
+ auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
if (!lattice)
return failure();
- SCCPLatticeValue &latticeValue = lattice->getValue();
- if (!latticeValue.constant)
+ const ConstantValue &latticeValue = lattice->getValue();
+ if (!latticeValue.getConstantValue())
return failure();
// Attempt to materialize a constant for the given value.
- Dialect *dialect = latticeValue.constantDialect;
- Value constant = folder.getOrCreateConstant(
- builder, dialect, latticeValue.constant, value.getType(), value.getLoc());
+ Dialect *dialect = latticeValue.getConstantDialect();
+ Value constant = folder.getOrCreateConstant(builder, dialect,
+ latticeValue.getConstantValue(),
+ value.getType(), value.getLoc());
if (!constant)
return failure();
@@ -192,7 +59,7 @@ static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
/// Rewrite the given regions using the computing analysis. This replaces the
/// uses of all values that have been computed to be constant, and erases as
/// many newly dead operations.
-static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
+static void rewrite(DataFlowSolver &solver, MLIRContext *context,
MutableArrayRef<Region> initialRegions) {
SmallVector<Block *> worklist;
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
@@ -216,7 +83,7 @@ static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
bool replacedAll = op.getNumResults() != 0;
for (Value res : op.getResults())
replacedAll &=
- succeeded(replaceWithConstant(analysis, builder, folder, res));
+ succeeded(replaceWithConstant(solver, builder, folder, res));
// If all of the results of the operation were replaced, try to erase
// the operation completely.
@@ -233,7 +100,7 @@ static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
// Replace any block arguments with constants.
builder.setInsertionPointToStart(block);
for (BlockArgument arg : block->getArguments())
- (void)replaceWithConstant(analysis, builder, folder, arg);
+ (void)replaceWithConstant(solver, builder, folder, arg);
}
}
@@ -250,9 +117,12 @@ struct SCCP : public SCCPBase<SCCP> {
void SCCP::runOnOperation() {
Operation *op = getOperation();
- SCCPAnalysis analysis(op->getContext());
- analysis.run(op);
- rewrite(analysis, op->getContext(), op->getRegions());
+ DataFlowSolver solver;
+ solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
+ if (failed(solver.initializeAndRun(op)))
+ return signalPassFailure();
+ rewrite(solver, op->getContext(), op->getRegions());
}
std::unique_ptr<Pass> mlir::createSCCPPass() {
More information about the Mlir-commits
mailing list