[llvm-branch-commits] [mlir] 08316f8 - [mlir] Add Dead Code Analysis
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jun 29 11:00:30 PDT 2022
Author: Mogball
Date: 2022-06-29T10:59:44-07:00
New Revision: 08316f82e5537e4ba8eff1d82ac4eb3b88e20307
URL: https://github.com/llvm/llvm-project/commit/08316f82e5537e4ba8eff1d82ac4eb3b88e20307
DIFF: https://github.com/llvm/llvm-project/commit/08316f82e5537e4ba8eff1d82ac4eb3b88e20307.diff
LOG: [mlir] Add Dead Code Analysis
This patch implements the analysis state classes needed for sparse data-flow analysis and implements a dead-code analysis using those states to determine liveness of blocks, control-flow edges, region predecessors, and function callsites.
Added:
mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
mlir/test/Analysis/test-dead-code-analysis.mlir
mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp
Modified:
mlir/include/mlir/Analysis/DataFlowAnalysis.h
mlir/include/mlir/Analysis/DataFlowFramework.h
mlir/lib/Analysis/CMakeLists.txt
mlir/test/lib/Analysis/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlowAnalysis.h b/mlir/include/mlir/Analysis/DataFlowAnalysis.h
index 883624dfa4805..bc7664b6a20ee 100644
--- a/mlir/include/mlir/Analysis/DataFlowAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlowAnalysis.h
@@ -22,34 +22,17 @@
#ifndef MLIR_ANALYSIS_DATAFLOWANALYSIS_H
#define MLIR_ANALYSIS_DATAFLOWANALYSIS_H
+#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/Support/Allocator.h"
-namespace mlir {
-//===----------------------------------------------------------------------===//
-// ChangeResult
-//===----------------------------------------------------------------------===//
-
-/// A result type used to indicate if a change happened. Boolean operations on
-/// ChangeResult behave as though `Change` is truthy.
-enum class ChangeResult {
- NoChange,
- Change,
-};
-inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
- return lhs == ChangeResult::Change ? lhs : rhs;
-}
-inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
- lhs = lhs | rhs;
- return lhs;
-}
-inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
- return lhs == ChangeResult::NoChange ? lhs : rhs;
-}
+/// TODO: Remove this file when SCCP and integer range analysis have been ported
+/// to the new framework.
+namespace mlir {
//===----------------------------------------------------------------------===//
// AbstractLatticeElement
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 9b85182532232..19d8fc0c3e19b 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -16,7 +16,6 @@
#ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
#define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
-#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/StorageUniquer.h"
#include "llvm/ADT/SetVector.h"
@@ -25,6 +24,27 @@
namespace mlir {
+//===----------------------------------------------------------------------===//
+// ChangeResult
+//===----------------------------------------------------------------------===//
+
+/// Outcome of an operation that may modify analysis state. In the boolean
+/// operators below, `Change` plays the role of `true`.
+enum class ChangeResult {
+  NoChange,
+  Change,
+};
+/// Logical "or": the combination reports a change if either side does.
+inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
+  if (lhs == ChangeResult::Change)
+    return ChangeResult::Change;
+  return rhs;
+}
+/// Accumulate `rhs` into `lhs` with `operator|` and return `lhs`.
+inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
+  return lhs = lhs | rhs;
+}
+/// Logical "and": the combination reports a change only if both sides do.
+inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
+  if (lhs == ChangeResult::NoChange)
+    return ChangeResult::NoChange;
+  return rhs;
+}
+
/// Forward declare the analysis state class.
class AnalysisState;
@@ -137,6 +157,12 @@ struct ProgramPoint
using ParentTy::PointerUnion;
/// Allow implicit conversion from the parent type.
ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
+ /// Allow implicit conversions from operation wrappers, i.e. any type that is
+ /// convertible to `Operation *`. `Operation *` itself is excluded so that it
+ /// keeps going through the `ParentTy` constructor above.
+ /// TODO: For Windows only. Find a better solution.
+ template <typename OpT, typename = typename std::enable_if_t<
+ std::is_convertible<OpT, Operation *>::value &&
+ !std::is_same<OpT, Operation *>::value>>
+ ProgramPoint(OpT op) : ParentTy(op) {}
/// Print the program point.
void print(raw_ostream &os) const;
@@ -180,7 +206,7 @@ class DataFlowSolver {
/// does not exist.
template <typename StateT, typename PointT>
const StateT *lookupState(PointT point) const {
- auto it = analysisStates.find({point, TypeID::get<StateT>()});
+ auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
if (it == analysisStates.end())
return nullptr;
return static_cast<const StateT *>(it->second.get());
diff --git a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
new file mode 100644
index 0000000000000..ee9392b387e60
--- /dev/null
+++ b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
@@ -0,0 +1,418 @@
+//===- SparseDataFlowAnalysis.h - Sparse data-flow analysis ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements sparse data-flow analysis using the data-flow analysis
+// framework. The analysis is forward and conditional and uses the results of
+// dead code analysis to prune dead code during the analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H
+#define MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H
+
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseLattice
+//===----------------------------------------------------------------------===//
+
+/// Abstract base of every sparse lattice. A lattice is attached to a single
+/// SSA value and carries the information that sparse data-flow analysis
+/// propagates through the IR.
+class AbstractSparseLattice : public AnalysisState {
+public:
+  /// Construct a lattice anchored at `value`; lattices only exist for values.
+  AbstractSparseLattice(Value value) : AnalysisState(value) {}
+
+  /// Merge the contents of `rhs` into this lattice and report whether this
+  /// lattice was modified.
+  virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
+
+  /// Whether the element has reached a fixpoint, i.e. further joins cannot
+  /// modify it.
+  virtual bool isAtFixpoint() const = 0;
+
+  /// Move the element to its pessimistic fixpoint: the value may have
+  /// conflicting states, so only the most conservative information may be
+  /// relied on afterwards.
+  virtual ChangeResult markPessimisticFixpoint() = 0;
+
+  /// On change, enqueue every user of the anchored value for each analysis
+  /// registered through `useDefSubscribe`.
+  void onUpdate(DataFlowSolver *solver) const override;
+
+  /// Register `analysis` to be re-run on all users of the value whenever this
+  /// lattice changes. Cheaper than tracking each user in the dependency map.
+  void useDefSubscribe(DataFlowAnalysis *analysis) {
+    useDefSubscribers.insert(analysis);
+  }
+
+private:
+  /// Insertion-ordered, duplicate-free set of analyses.
+  using SubscriberSet =
+      SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
+                SmallPtrSet<DataFlowAnalysis *, 4>>;
+
+  /// Analyses re-invoked on the value's users when this lattice changes.
+  SubscriberSet useDefSubscribers;
+};
+
+//===----------------------------------------------------------------------===//
+// Lattice
+//===----------------------------------------------------------------------===//
+
+/// This class represents a lattice holding a specific value of type `ValueT`.
+/// Lattice values (`ValueT`) are required to adhere to the following:
+///
+/// * static ValueT join(const ValueT &lhs, const ValueT &rhs);
+/// - This method conservatively joins the information held by `lhs`
+/// and `rhs` into a new value. This method is required to be monotonic.
+/// * bool operator==(const ValueT &rhs) const;
+/// * void print(raw_ostream &os) const;
+/// - Used by `Lattice::print` for both the known and the optimistic value.
+/// * A default constructor, used to build the default known value.
+///
+template <typename ValueT>
+class Lattice : public AbstractSparseLattice {
+public:
+ /// Lattices anchored on a `Value` use the inherited constructor; their known
+ /// value is then default-constructed.
+ using AbstractSparseLattice::AbstractSparseLattice;
+
+ /// Get a lattice element with a known value.
+ /// NOTE(review): this element is anchored on a null `Value`; presumably it is
+ /// meant for use outside the solver -- confirm with callers.
+ Lattice(const ValueT &knownValue = ValueT())
+ : AbstractSparseLattice(Value()), knownValue(knownValue) {}
+
+ /// Return the value held by this lattice. This requires that the value is
+ /// initialized.
+ ValueT &getValue() {
+ assert(!isUninitialized() && "expected known lattice element");
+ return *optimisticValue;
+ }
+ // The non-const accessor does not modify the lattice, so it is reused here.
+ const ValueT &getValue() const {
+ return const_cast<Lattice<ValueT> *>(this)->getValue();
+ }
+
+ /// Returns true if the value of this lattice hasn't yet been initialized.
+ bool isUninitialized() const override { return !optimisticValue.hasValue(); }
+ /// Force the initialization of the element by setting it to its pessimistic
+ /// fixpoint.
+ ChangeResult defaultInitialize() override {
+ return markPessimisticFixpoint();
+ }
+
+ /// Returns true if the lattice has reached a fixpoint. A fixpoint is when
+ /// the information optimistically assumed to be true is the same as the
+ /// information known to be true. An uninitialized lattice (no optimistic
+ /// value) never compares equal to `knownValue`, so it is never at a fixpoint.
+ bool isAtFixpoint() const override { return optimisticValue == knownValue; }
+
+ /// Join the information contained in the 'rhs' lattice into this
+ /// lattice. Returns if the state of the current lattice changed.
+ ChangeResult join(const AbstractSparseLattice &rhs) override {
+ // NOTE(review): assumes `rhs` is also a `Lattice<ValueT>`; the downcast is
+ // unchecked.
+ const Lattice<ValueT> &rhsLattice =
+ static_cast<const Lattice<ValueT> &>(rhs);
+
+ // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do.
+ if (isAtFixpoint() || rhsLattice.isUninitialized())
+ return ChangeResult::NoChange;
+
+ // Join the rhs value into this lattice.
+ return join(rhsLattice.getValue());
+ }
+
+ /// Join the information contained in the 'rhs' value into this
+ /// lattice. Returns if the state of the current lattice changed.
+ ChangeResult join(const ValueT &rhs) {
+ // If the current lattice is uninitialized, copy the rhs value.
+ if (isUninitialized()) {
+ optimisticValue = rhs;
+ return ChangeResult::Change;
+ }
+
+ // Otherwise, join rhs with the current optimistic value.
+ ValueT newValue = ValueT::join(*optimisticValue, rhs);
+ // Debug checks: joining the result again with either input must leave it
+ // unchanged.
+ assert(ValueT::join(newValue, *optimisticValue) == newValue &&
+ "expected `join` to be monotonic");
+ assert(ValueT::join(newValue, rhs) == newValue &&
+ "expected `join` to be monotonic");
+
+ // Update the current optimistic value if something changed.
+ if (newValue == optimisticValue)
+ return ChangeResult::NoChange;
+
+ optimisticValue = newValue;
+ return ChangeResult::Change;
+ }
+
+ /// Mark the lattice element as having reached a pessimistic fixpoint. This
+ /// means that the lattice may potentially have conflicting value states,
+ /// and only the conservatively known value state should be relied on.
+ ChangeResult markPessimisticFixpoint() override {
+ if (isAtFixpoint())
+ return ChangeResult::NoChange;
+
+ // For this fixed point, we take whatever we knew to be true and set that
+ // to our optimistic value.
+ optimisticValue = knownValue;
+ return ChangeResult::Change;
+ }
+
+ /// Print the lattice element as `[known, optimistic]`, using `<NULL>` for an
+ /// uninitialized optimistic value.
+ void print(raw_ostream &os) const override {
+ os << "[";
+ knownValue.print(os);
+ os << ", ";
+ if (optimisticValue) {
+ optimisticValue->print(os);
+ } else {
+ os << "<NULL>";
+ }
+ os << "]";
+ }
+
+private:
+ /// The value that is conservatively known to be true.
+ ValueT knownValue;
+ /// The currently computed value that is optimistically assumed to be true,
+ /// or None if the lattice element is uninitialized.
+ Optional<ValueT> optimisticValue;
+};
+
+//===----------------------------------------------------------------------===//
+// Executable
+//===----------------------------------------------------------------------===//
+
+/// Analysis state recording whether its program point -- a block or a
+/// control-flow edge -- is known to be live.
+class Executable : public AnalysisState {
+public:
+  using AnalysisState::AnalysisState;
+
+  /// Liveness always has a value (dead until shown live), so the state is
+  /// never uninitialized.
+  bool isUninitialized() const override { return false; }
+
+  /// Nothing to initialize; reports no change.
+  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+
+  /// Mark the program point as live. Reports a change only the first time.
+  ChangeResult setToLive();
+
+  /// Whether the program point has been shown to be live.
+  bool isLive() const { return live; }
+
+  /// Print "live" or "dead".
+  void print(raw_ostream &os) const override;
+
+  /// Once the point becomes live, re-invoke subscribed analyses: for a block,
+  /// on the block itself and on every operation it contains; for a CFG edge,
+  /// on the destination block.
+  void onUpdate(DataFlowSolver *solver) const override;
+
+  /// Register `analysis` to be re-invoked when the liveness changes.
+  void blockContentSubscribe(DataFlowAnalysis *analysis) {
+    subscribers.insert(analysis);
+  }
+
+private:
+  /// Insertion-ordered, duplicate-free set of analyses.
+  using SubscriberSet =
+      SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
+                SmallPtrSet<DataFlowAnalysis *, 4>>;
+
+  /// Optimistically assume the program point is dead until shown otherwise.
+  bool live = false;
+
+  /// Analyses to re-invoke when this state changes.
+  SubscriberSet subscribers;
+};
+
+//===----------------------------------------------------------------------===//
+// ConstantValue
+//===----------------------------------------------------------------------===//
+
+/// Lattice value holding the known constant of an SSA value, if any.
+class ConstantValue {
+public:
+  /// Create a value holding `knownValue` (null when unknown), together with
+  /// the dialect that may be used to materialize it.
+  ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr)
+      : constant(knownValue), dialect(dialect) {}
+
+  /// The constant attribute, or null when no constant was determined.
+  Attribute getConstantValue() const { return constant; }
+
+  /// The dialect instance able to materialize the constant, if any.
+  Dialect *getConstantDialect() const { return dialect; }
+
+  /// Two values are equal when they hold the same constant attribute; the
+  /// materializing dialect is not compared.
+  bool operator==(const ConstantValue &rhs) const {
+    return rhs.constant == constant;
+  }
+
+  /// Join two constant values: the result is the shared value when both are
+  /// equal, and an unknown (null) constant when they are different.
+  static ConstantValue join(const ConstantValue &lhs,
+                            const ConstantValue &rhs) {
+    if (lhs == rhs)
+      return lhs;
+    return ConstantValue();
+  }
+
+  /// Print the constant value.
+  void print(raw_ostream &os) const;
+
+private:
+  /// The constant attribute; null when unknown.
+  Attribute constant;
+  /// A dialect instance that can be used to materialize the constant.
+  Dialect *dialect;
+};
+
+//===----------------------------------------------------------------------===//
+// PredecessorState
+//===----------------------------------------------------------------------===//
+
+/// Analysis state holding the set of known predecessors of a program point.
+/// Sparse data-flow analysis uses it to reason about region control-flow and
+/// the callgraph. The set may be flagged as incomplete, for example when not
+/// all callsites of a callable are visible.
+class PredecessorState : public AnalysisState {
+public:
+  using AnalysisState::AnalysisState;
+
+  /// The state always holds a value (an empty, complete set to begin with).
+  bool isUninitialized() const override { return false; }
+
+  /// Nothing to initialize; reports no change.
+  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+
+  /// Print the known predecessors, prefixed with "(all) " when complete.
+  void print(raw_ostream &os) const override;
+
+  /// Whether every predecessor is known.
+  bool allPredecessorsKnown() const { return allKnown; }
+
+  /// Flag the predecessor set as possibly incomplete. Reports a change only
+  /// the first time.
+  ChangeResult setHasUnknownPredecessors() {
+    const bool wasComplete = allKnown;
+    allKnown = false;
+    return wasComplete ? ChangeResult::Change : ChangeResult::NoChange;
+  }
+
+  /// The known predecessors, in the order they were added.
+  ArrayRef<Operation *> getKnownPredecessors() const {
+    return knownPredecessors.getArrayRef();
+  }
+
+  /// Record `predecessor` as known. Reports a change if it was not already
+  /// present.
+  ChangeResult join(Operation *predecessor) {
+    if (!knownPredecessors.insert(predecessor))
+      return ChangeResult::NoChange;
+    return ChangeResult::Change;
+  }
+
+private:
+  /// Insertion-ordered, duplicate-free set of operations.
+  using PredecessorSet = SetVector<Operation *, SmallVector<Operation *, 4>,
+                                   SmallPtrSet<Operation *, 4>>;
+
+  /// Optimistically assume that all predecessors are known.
+  bool allKnown = true;
+
+  /// The known control-flow predecessors of this program point.
+  PredecessorSet knownPredecessors;
+};
+
+//===----------------------------------------------------------------------===//
+// CFGEdge
+//===----------------------------------------------------------------------===//
+
+/// This program point represents a control-flow edge between a block and one
+/// of its successors. The edge is identified by the (from, to) block pair it
+/// wraps.
+class CFGEdge
+ : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
+public:
+ using Base::Base;
+
+ /// Get the block from which the edge originates.
+ Block *getFrom() const { return getValue().first; }
+ /// Get the target block.
+ Block *getTo() const { return getValue().second; }
+
+ /// Print the source block, then the target block.
+ void print(raw_ostream &os) const override;
+ /// Get a fused location of the regions that contain the source and target
+ /// blocks.
+ Location getLoc() const override;
+};
+
+//===----------------------------------------------------------------------===//
+// DeadCodeAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Dead code analysis analyzes control-flow, as understood by
+/// `RegionBranchOpInterface` and `BranchOpInterface`, and the callgraph, as
+/// understood by `CallableOpInterface` and `CallOpInterface`.
+///
+/// This analysis uses known constant values of operands to determine the
+/// liveness of each block and each edge between a block and its predecessors.
+/// For region control-flow, this analysis determines the predecessor operations
+/// for region entry blocks and region control-flow operations. For the
+/// callgraph, this analysis determines the callsites and live returns of every
+/// function.
+///
+/// Liveness is recorded in `Executable` states (on blocks and on `CFGEdge`
+/// points); predecessors are recorded in `PredecessorState` states.
+class DeadCodeAnalysis : public DataFlowAnalysis {
+public:
+ /// Construct the analysis and register the `CFGEdge` program point kind.
+ explicit DeadCodeAnalysis(DataFlowSolver &solver);
+
+ /// Initialize the analysis: mark the entry blocks of `top` as live, seed the
+ /// predecessor states of symbol callables, and visit every operation with
+ /// potential control-flow semantics.
+ LogicalResult initialize(Operation *top) override;
+
+ /// Visit an operation with control-flow semantics and deduce which of its
+ /// successors are live. Visiting a block is a no-op.
+ LogicalResult visit(ProgramPoint point) override;
+
+private:
+ /// Find and mark symbol callables with potentially unknown callsites as
+ /// having overdefined predecessors. `top` is the top-level operation that the
+ /// analysis is operating on.
+ void initializeSymbolCallables(Operation *top);
+
+ /// Recursively initialize the analysis on nested regions.
+ LogicalResult initializeRecursively(Operation *op);
+
+ /// Visit the given call operation and compute any necessary lattice state.
+ void visitCallOperation(CallOpInterface call);
+
+ /// Visit the given branch operation with successors and try to determine
+ /// which are live from the current block.
+ void visitBranchOperation(BranchOpInterface branch);
+
+ /// Visit the given region branch operation, which defines regions, and
+ /// compute any necessary lattice state. This also resolves the lattice state
+ /// of both the operation results and any nested regions.
+ void visitRegionBranchOperation(RegionBranchOpInterface branch);
+
+ /// Visit the given terminator operation that exits a region under an
+ /// operation with control-flow semantics. These are terminators with no CFG
+ /// successors.
+ void visitRegionTerminator(Operation *op, RegionBranchOpInterface branch);
+
+ /// Visit the given terminator operation that exits a callable region. These
+ /// are terminators with no CFG successors.
+ void visitCallableTerminator(Operation *op, CallableOpInterface callable);
+
+ /// Mark the edge between `from` and `to`, and the block `to` itself, as
+ /// executable.
+ void markEdgeLive(Block *from, Block *to);
+
+ /// Mark the entry blocks of the operation as executable.
+ void markEntryBlocksLive(Operation *op);
+
+ /// Get the constant values of the operands of the operation. Returns none if
+ /// any of the operand lattices are uninitialized. As a side effect, this
+ /// analysis is subscribed to each operand's constant lattice.
+ Optional<SmallVector<Attribute>> getOperandValues(Operation *op);
+
+ /// A symbol table collection used for O(1) symbol lookups during the
+ /// analysis.
+ SymbolTableCollection symbolTable;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index fe5f3832322ce..662bc1084c5ff 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -7,6 +7,7 @@ set(LLVM_OPTIONAL_SOURCES
IntRangeAnalysis.cpp
Liveness.cpp
SliceAnalysis.cpp
+ SparseDataFlowAnalysis.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
)
@@ -21,6 +22,7 @@ add_mlir_library(MLIRAnalysis
IntRangeAnalysis.cpp
Liveness.cpp
SliceAnalysis.cpp
+ SparseDataFlowAnalysis.cpp
AliasAnalysis/LocalAliasAnalysis.cpp
diff --git a/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
new file mode 100644
index 0000000000000..41ff8381b0749
--- /dev/null
+++ b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
@@ -0,0 +1,413 @@
+//===- SparseDataFlowAnalysis.cpp - Sparse data-flow analysis -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SparseDataFlowAnalysis.h"
+
+#define DEBUG_TYPE "dataflow"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseLattice
+//===----------------------------------------------------------------------===//
+
+void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
+  // A sparse lattice is anchored on an SSA value: revisit each of its users
+  // with every use-def subscriber.
+  Value value = point.get<Value>();
+  for (Operation *user : value.getUsers()) {
+    for (DataFlowAnalysis *analysis : useDefSubscribers)
+      solver->enqueue({user, analysis});
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Executable
+//===----------------------------------------------------------------------===//
+
+ChangeResult Executable::setToLive() {
+  // Liveness only ever goes from dead to live, so a change happens at most
+  // once.
+  const ChangeResult result =
+      live ? ChangeResult::NoChange : ChangeResult::Change;
+  live = true;
+  return result;
+}
+
+void Executable::print(raw_ostream &os) const {
+  if (live)
+    os << "live";
+  else
+    os << "dead";
+}
+
+void Executable::onUpdate(DataFlowSolver *solver) const {
+  // A block became live: revisit the block itself, then every operation in it.
+  if (auto *block = point.dyn_cast<Block *>()) {
+    for (DataFlowAnalysis *analysis : subscribers)
+      solver->enqueue({block, analysis});
+    for (DataFlowAnalysis *analysis : subscribers)
+      for (Operation &op : *block)
+        solver->enqueue({&op, analysis});
+    return;
+  }
+
+  // A control-flow edge became live: revisit its destination block.
+  auto *programPoint = point.dyn_cast<GenericProgramPoint *>();
+  if (!programPoint)
+    return;
+  auto *edge = dyn_cast<CFGEdge>(programPoint);
+  if (!edge)
+    return;
+  Block *successor = edge->getTo();
+  for (DataFlowAnalysis *analysis : subscribers)
+    solver->enqueue({successor, analysis});
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantValue
+//===----------------------------------------------------------------------===//
+
+void ConstantValue::print(raw_ostream &os) const {
+  if (!constant) {
+    os << "<NO VALUE>";
+    return;
+  }
+  constant.print(os);
+}
+
+//===----------------------------------------------------------------------===//
+// PredecessorState
+//===----------------------------------------------------------------------===//
+
+void PredecessorState::print(raw_ostream &os) const {
+  // The "(all) " prefix marks a predecessor list known to be complete.
+  os << (allPredecessorsKnown() ? "(all) " : "") << "predecessors:\n";
+  for (Operation *predecessor : getKnownPredecessors())
+    os << " " << *predecessor << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// CFGEdge
+//===----------------------------------------------------------------------===//
+
+Location CFGEdge::getLoc() const {
+  // Fuse the locations of the regions holding the source and target blocks.
+  Region *fromRegion = getFrom()->getParent();
+  Region *toRegion = getTo()->getParent();
+  return FusedLoc::get(fromRegion->getContext(),
+                       {fromRegion->getLoc(), toRegion->getLoc()});
+}
+
+void CFGEdge::print(raw_ostream &os) const {
+  // Source block, an arrow, then the target block.
+  Block *from = getFrom();
+  Block *to = getTo();
+  from->print(os);
+  os << "\n -> \n";
+  to->print(os);
+}
+
+//===----------------------------------------------------------------------===//
+// DeadCodeAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Construct the analysis and register the program point kinds it creates.
+DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
+ : DataFlowAnalysis(solver) {
+ // `markEdgeLive` creates `CFGEdge` points via `getProgramPoint<CFGEdge>`.
+ registerPointKind<CFGEdge>();
+}
+
+/// Seed the analysis on `top`: its entry blocks are live, symbol callables
+/// with possibly unknown callsites get overdefined predecessors, and every
+/// nested operation with control-flow semantics is visited once.
+LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
+  // Mark the top-level blocks as executable.
+  markEntryBlocksLive(top);
+
+  // Mark as overdefined the predecessors of symbol callables with potentially
+  // unknown predecessors.
+  initializeSymbolCallables(top);
+
+  return initializeRecursively(top);
+}
+
+/// Mark symbol callables whose callsites may not all be visible as having
+/// unknown predecessors. This covers callables that are public, callables
+/// nested in a symbol table whose uses are not all visible, and symbols with a
+/// non-call use.
+void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
+  auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
+    Region &symbolTableRegion = symTable->getRegion(0);
+    Block *symbolTableBlock = &symbolTableRegion.front();
+
+    bool foundSymbolCallable = false;
+    for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
+      Region *callableRegion = callable.getCallableRegion();
+      if (!callableRegion)
+        continue;
+      auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
+      if (!symbol)
+        continue;
+
+      // Public symbol callables or those for which we can't see all uses have
+      // potentially unknown callsites.
+      if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
+        auto *state = getOrCreate<PredecessorState>(callable);
+        propagateIfChanged(state, state->setHasUnknownPredecessors());
+      }
+      foundSymbolCallable = true;
+    }
+
+    // Exit early if no eligible symbol callables were found in the table.
+    if (!foundSymbolCallable)
+      return;
+
+    // Walk the symbol table to check for non-call uses of symbols.
+    Optional<SymbolTable::UseRange> uses =
+        SymbolTable::getSymbolUses(&symbolTableRegion);
+    if (!uses) {
+      // If we couldn't gather the symbol uses, conservatively assume that
+      // we can't track information for any nested symbols.
+      return top->walk([&](CallableOpInterface callable) {
+        auto *state = getOrCreate<PredecessorState>(callable);
+        propagateIfChanged(state, state->setHasUnknownPredecessors());
+      });
+    }
+
+    for (const SymbolTable::SymbolUse &use : *uses) {
+      if (isa<CallOpInterface>(use.getUser()))
+        continue;
+      // If a callable symbol has a non-call use, then we can't be guaranteed to
+      // know all callsites.
+      Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
+      // The reference may not resolve to an operation under `top`. In that
+      // case there is no state to update; creating one would key it on a null
+      // program point.
+      if (!symbol)
+        continue;
+      auto *state = getOrCreate<PredecessorState>(symbol);
+      propagateIfChanged(state, state->setHasUnknownPredecessors());
+    }
+  };
+  SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
+                                walkFn);
+}
+
+/// Visit `op` if it has control-flow semantics (regions, successors, a
+/// terminator trait, or a call), then recurse into every nested operation.
+/// Stops at the first failing visit.
+LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
+  // Initialize the analysis by visiting every op with control-flow semantics.
+  if (op->getNumRegions() || op->getNumSuccessors() ||
+      op->hasTrait<OpTrait::IsTerminator>() || isa<CallOpInterface>(op)) {
+    // When the liveness of the parent block changes, make sure to re-invoke the
+    // analysis on the op.
+    if (Block *parentBlock = op->getBlock())
+      getOrCreate<Executable>(parentBlock)->blockContentSubscribe(this);
+    // Visit the op.
+    if (failed(visit(op)))
+      return failure();
+  }
+  // Recurse on nested operations.
+  for (Region &region : op->getRegions())
+    for (Operation &nestedOp : region.getOps())
+      if (failed(initializeRecursively(&nestedOp)))
+        return failure();
+  return success();
+}
+
+void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
+  // The destination block becomes reachable...
+  auto *blockState = getOrCreate<Executable>(to);
+  propagateIfChanged(blockState, blockState->setToLive());
+
+  // ...and so does the edge that reaches it.
+  auto *edgeState =
+      getOrCreate<Executable>(getProgramPoint<CFGEdge>(from, to));
+  propagateIfChanged(edgeState, edgeState->setToLive());
+}
+
+/// Mark the entry block of every non-empty region of `op` as executable.
+void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
+  for (Region &region : op->getRegions()) {
+    // An empty region has no entry block to mark.
+    if (region.empty())
+      continue;
+    auto *state = getOrCreate<Executable>(&region.front());
+    propagateIfChanged(state, state->setToLive());
+  }
+}
+
+/// Visit a program point. Blocks need no work. For a live operation, record
+/// callsites, decide which region entry blocks are live, handle region and
+/// callable exits, and mark live CFG successors.
+LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
+  // Nothing to do when visiting a block.
+  if (point.is<Block *>())
+    return success();
+  auto *op = point.dyn_cast<Operation *>();
+  if (!op)
+    return emitError(point.getLoc(), "unknown program point kind");
+
+  // If the parent block is not executable, there is nothing to do. An op
+  // without a parent block (e.g. the top-level op) has no liveness to query;
+  // its entry blocks are marked live by `initialize`. Bail out without
+  // creating a state keyed on a null block.
+  Block *parentBlock = op->getBlock();
+  if (!parentBlock || !getOrCreate<Executable>(parentBlock)->isLive())
+    return success();
+
+  // We have a live call op. Add this as a live predecessor of the callee.
+  if (auto call = dyn_cast<CallOpInterface>(op))
+    visitCallOperation(call);
+
+  // Visit the regions.
+  if (op->getNumRegions()) {
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+      // We can reason about the region control-flow.
+      visitRegionBranchOperation(branch);
+    } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
+      // A callable's entry blocks are executable if its callsites could not
+      // be resolved or are known to be non-empty.
+      const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
+      if (!callsites->allPredecessorsKnown() ||
+          !callsites->getKnownPredecessors().empty())
+        markEntryBlocksLive(callable);
+    } else {
+      // Otherwise, conservatively mark all entry blocks as executable.
+      markEntryBlocksLive(op);
+    }
+  }
+
+  // Visit terminators that exit a region (no CFG successors).
+  if (op->hasTrait<OpTrait::IsTerminator>() && !op->getNumSuccessors()) {
+    Operation *parentOp = op->getParentOp();
+    if (auto branch = dyn_cast_or_null<RegionBranchOpInterface>(parentOp)) {
+      // Visit the exiting terminator of a region.
+      visitRegionTerminator(op, branch);
+    } else if (auto callable =
+                   dyn_cast_or_null<CallableOpInterface>(parentOp)) {
+      // Visit the exiting terminator of a callable.
+      visitCallableTerminator(op, callable);
+    }
+  }
+
+  // Visit the successors.
+  if (op->getNumSuccessors()) {
+    if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+      // We can reason about the control-flow from the branch operands.
+      visitBranchOperation(branch);
+    } else {
+      // Otherwise, conservatively mark all successors as executable.
+      for (Block *successor : op->getSuccessors())
+        markEdgeLive(parentBlock, successor);
+    }
+  }
+
+  return success();
+}
+
+void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
+  // Resolve the callee: an indirect callee value is traced to its defining
+  // op, and a symbol reference is looked up through the symbol table.
+  Value calleeValue = call.getCallableForCallee().dyn_cast<Value>();
+  Operation *callableOp = calleeValue ? calleeValue.getDefiningOp()
+                                      : call.resolveCallable(&symbolTable);
+
+  // Only symbol callables that are not externally defined have callsites
+  // tracked; a call to an externally-defined callable has unknown
+  // predecessors.
+  // TODO: Add support for non-symbol callables when necessary. If the
+  // callable has non-call uses we would mark as having reached pessimistic
+  // fixpoint, otherwise allow for propagating the return values out.
+  bool isTrackable = false;
+  if (isa_and_nonnull<SymbolOpInterface>(callableOp)) {
+    auto callable = dyn_cast<CallableOpInterface>(callableOp);
+    isTrackable = !callable || callable.getCallableRegion();
+  }
+
+  if (isTrackable) {
+    // Add the live callsite.
+    auto *callsites = getOrCreate<PredecessorState>(callableOp);
+    propagateIfChanged(callsites, callsites->join(call));
+    return;
+  }
+
+  // Mark this call op's predecessors as overdefined.
+  auto *predecessors = getOrCreate<PredecessorState>(call);
+  propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
+}
+
+/// Collect, in operand order, the constant value attribute of every operand of
+/// `op`, using `getLattice` to fetch each operand's constant-value lattice.
+/// Returns an empty Optional as soon as one lattice is still uninitialized,
+/// which tells the caller to bail out until more information is available.
+/// Each collected entry is whatever `getConstantValue()` reports for the
+/// lattice (presumably a null attribute when the value is not a known
+/// constant -- confirm against `ConstantValue`).
+static Optional<SmallVector<Attribute>> getOperandValuesImpl(
+    Operation *op,
+    function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
+  unsigned numOperands = op->getNumOperands();
+  SmallVector<Attribute> constants;
+  constants.reserve(numOperands);
+  for (unsigned idx = 0; idx < numOperands; ++idx) {
+    const Lattice<ConstantValue> *lattice = getLattice(op->getOperand(idx));
+    // Nothing is known about this operand yet; give up for now.
+    if (lattice->isUninitialized())
+      return {};
+    constants.push_back(lattice->getValue().getConstantValue());
+  }
+  return constants;
+}
+
+/// Return the constant operand values of `op` (see `getOperandValuesImpl`).
+/// Each inspected operand lattice is fetched with `getOrCreate` and has this
+/// analysis registered through `useDefSubscribe`, presumably so that users of
+/// the value are revisited when its constant value changes.
+Optional<SmallVector<Attribute>>
+DeadCodeAnalysis::getOperandValues(Operation *op) {
+  auto lookupAndSubscribe =
+      [&](Value operand) -> const Lattice<ConstantValue> * {
+    auto *lattice = getOrCreate<Lattice<ConstantValue>>(operand);
+    lattice->useDefSubscribe(this);
+    return lattice;
+  };
+  return getOperandValuesImpl(op, lookupAndSubscribe);
+}
+
+/// Mark the outgoing control-flow edges of `branch` as live. When the constant
+/// operand values let the branch name a single successor, only that edge is
+/// marked. Otherwise every successor edge is marked. Nothing is marked while
+/// any operand's constant value is still uninitialized.
+void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
+  Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
+  if (!operands)
+    return;
+
+  Block *source = branch->getBlock();
+  if (Block *taken = branch.getSuccessorForOperands(*operands)) {
+    markEdgeLive(source, taken);
+    return;
+  }
+
+  // The taken successor could not be deduced; assume any of them may be taken.
+  for (Block *successor : branch->getSuccessors())
+    markEdgeLive(source, successor);
+}
+
+/// Determine, from the constant operand values of `branch`, where control may
+/// go when first entering the op. The entry block of each reachable region is
+/// marked live and gets `branch` recorded as a predecessor. Nothing is done
+/// while any operand's constant value is still uninitialized.
+void DeadCodeAnalysis::visitRegionBranchOperation(
+    RegionBranchOpInterface branch) {
+  // Try to deduce which regions are executable.
+  Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
+  if (!operands)
+    return;
+
+  SmallVector<RegionSuccessor> successors;
+  branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
+
+  for (const RegionSuccessor &successor : successors) {
+    Region *region = successor.getSuccessor();
+    if (!region) {
+      // A null successor denotes the parent op itself. The op may branch
+      // directly to its results without entering any region (e.g. an `scf.if`
+      // without an else region whose condition is known to be false). The op
+      // is then a predecessor of its own results, just as the terminators
+      // that exit its regions are.
+      auto *predecessors = getOrCreate<PredecessorState>(branch);
+      propagateIfChanged(predecessors, predecessors->join(branch));
+      continue;
+    }
+
+    // Mark the entry block as executable.
+    auto *state = getOrCreate<Executable>(&region->front());
+    propagateIfChanged(state, state->setToLive());
+    // Add the parent op as a predecessor.
+    auto *predecessors = getOrCreate<PredecessorState>(&region->front());
+    propagateIfChanged(predecessors, predecessors->join(branch));
+  }
+}
+
+/// Visit a terminator `op` that exits a region of `branch`. Using the
+/// terminator's constant operand values, determine where control may go next.
+/// Each successor region's entry block is marked live and gets `op` as a
+/// predecessor. A transfer back to the parent op (null successor) records
+/// `op` as a predecessor of `branch` itself. Nothing is done while any
+/// operand's constant value is still uninitialized.
+void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
+                                             RegionBranchOpInterface branch) {
+  Optional<SmallVector<Attribute>> operands = getOperandValues(op);
+  if (!operands)
+    return;
+
+  SmallVector<RegionSuccessor> successors;
+  branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(),
+                             *operands, successors);
+
+  // Mark successor region entry blocks as executable and add this op to the
+  // list of predecessors.
+  for (const RegionSuccessor &successor : successors) {
+    PredecessorState *predecessors = nullptr;
+    if (Region *region = successor.getSuccessor()) {
+      auto *state = getOrCreate<Executable>(&region->front());
+      propagateIfChanged(state, state->setToLive());
+      predecessors = getOrCreate<PredecessorState>(&region->front());
+    } else {
+      // Add this terminator as a predecessor to the parent op.
+      predecessors = getOrCreate<PredecessorState>(branch);
+    }
+    propagateIfChanged(predecessors, predecessors->join(op));
+  }
+}
+
+/// Visit a terminator `op` that exits the body of `callable`. For every known
+/// callsite of `callable`, `op` is added to the callsite's predecessors when
+/// `op` is return-like. Otherwise the callsite's predecessors are marked
+/// unknown.
+///
+/// This is done even when `op` returns no values, the same way
+/// `visitRegionTerminator` records operand-less region terminators. Skipping
+/// such terminators would leave a call to a result-less callable with a
+/// complete but empty predecessor set, i.e. it would report that no
+/// terminator ever returns to the call.
+void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
+                                               CallableOpInterface callable) {
+  // Fetch the callsites of the callable. `getOrCreateFor` is given `op`,
+  // presumably so this terminator is revisited when new callsites appear.
+  auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
+  bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
+  for (Operation *predecessor : callsites->getKnownPredecessors()) {
+    assert(isa<CallOpInterface>(predecessor));
+    auto *predecessors = getOrCreate<PredecessorState>(predecessor);
+    if (canResolve) {
+      propagateIfChanged(predecessors, predecessors->join(op));
+    } else {
+      // If the terminator is not a return-like, then conservatively assume we
+      // can't resolve the predecessor.
+      propagateIfChanged(predecessors,
+                         predecessors->setHasUnknownPredecessors());
+    }
+  }
+}
diff --git a/mlir/test/Analysis/test-dead-code-analysis.mlir b/mlir/test/Analysis/test-dead-code-analysis.mlir
new file mode 100644
index 0000000000000..2668b8349d1ca
--- /dev/null
+++ b/mlir/test/Analysis/test-dead-code-analysis.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt -test-dead-code-analysis 2>&1 %s | FileCheck %s
+
+// Unconditional branch into ^bb1, which conditionally loops back to itself on
+// the non-constant %cond: all blocks and all edges are expected to be live.
+// CHECK: test_cfg:
+// CHECK: region #0
+// CHECK: ^bb0 = live
+// CHECK: ^bb1 = live
+// CHECK: from ^bb1 = live
+// CHECK: from ^bb0 = live
+// CHECK: ^bb2 = live
+// CHECK: from ^bb1 = live
+func.func @test_cfg(%cond: i1) -> ()
+ attributes {tag = "test_cfg"} {
+ cf.br ^bb1
+
+^bb1:
+ cf.cond_br %cond, ^bb1, ^bb2
+
+^bb2:
+ return
+}
+
+// Region control flow with a non-constant condition: both scf.if regions are
+// entered from the op itself, and both yields flow back to the op's results.
+func.func @test_region_control_flow(%cond: i1, %arg0: i64, %arg1: i64) -> () {
+ // CHECK: test_if:
+ // CHECK: region #0
+ // CHECK: region_preds: (all) predecessors:
+ // CHECK: scf.if
+ // CHECK: region #1
+ // CHECK: region_preds: (all) predecessors:
+ // CHECK: scf.if
+ // CHECK: op_preds: (all) predecessors:
+ // CHECK: scf.yield {then}
+ // CHECK: scf.yield {else}
+ scf.if %cond {
+ scf.yield {then}
+ } else {
+ scf.yield {else}
+ } {tag = "test_if"}
+
+ // NOTE(review): the test_while expectations below carry no FileCheck
+ // directive prefix, so they are not verified; confirm this is intentional.
+ // test_while:
+ // region #0
+ // region_preds: (all) predecessors:
+ // scf.while
+ // scf.yield
+ // region #1
+ // region_preds: (all) predecessors:
+ // scf.condition
+ // op_preds: (all) predecessors:
+ // scf.condition
+ %c2_i64 = arith.constant 2 : i64
+ %0:2 = scf.while (%arg2 = %arg0) : (i64) -> (i64, i64) {
+ %1 = arith.cmpi slt, %arg2, %arg1 : i64
+ scf.condition(%1) %arg2, %arg2 : i64, i64
+ } do {
+ ^bb0(%arg2: i64, %arg3: i64):
+ %1 = arith.muli %arg3, %c2_i64 : i64
+ scf.yield %1 : i64
+ } attributes {tag = "test_while"}
+
+ return
+}
+
+// Call graph tests.
+// - @foo is private and called only from @test_callgraph, so its callsite
+//   list (calls "a" and "b") is expected to be complete.
+// - @bar is public, so its callsite list is expected to be incomplete.
+// - @baz is an external declaration, so the call to it ("d") has unknown
+//   predecessors.
+// Each call's predecessors are the return ops of its callee.
+// CHECK: foo:
+// CHECK: region #0
+// CHECK: ^bb0 = live
+// CHECK: op_preds: (all) predecessors:
+// CHECK: func.call @foo(%{{.*}}) {tag = "a"}
+// CHECK: func.call @foo(%{{.*}}) {tag = "b"}
+func.func private @foo(%arg0: i32) -> i32
+ attributes {tag = "foo"} {
+ return {a} %arg0 : i32
+}
+
+// CHECK: bar:
+// CHECK: region #0
+// CHECK: ^bb0 = live
+// CHECK: op_preds: predecessors:
+// CHECK: func.call @bar(%{{.*}}) {tag = "c"}
+func.func @bar(%cond: i1) -> i32
+ attributes {tag = "bar"} {
+ cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+ %c0 = arith.constant 0 : i32
+ return {b} %c0 : i32
+
+^bb2:
+ %c1 = arith.constant 1 : i32
+ return {c} %c1 : i32
+}
+
+// CHECK: baz
+// CHECK: op_preds: (all) predecessors:
+func.func private @baz(i32) -> i32 attributes {tag = "baz"}
+
+func.func @test_callgraph(%cond: i1, %arg0: i32) -> i32 {
+ // CHECK: a:
+ // CHECK: op_preds: (all) predecessors:
+ // CHECK: func.return {a}
+ %0 = func.call @foo(%arg0) {tag = "a"} : (i32) -> i32
+ cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+ // CHECK: b:
+ // CHECK: op_preds: (all) predecessors:
+ // CHECK: func.return {a}
+ %1 = func.call @foo(%arg0) {tag = "b"} : (i32) -> i32
+ return %1 : i32
+
+^bb2:
+ // CHECK: c:
+ // CHECK: op_preds: (all) predecessors:
+ // CHECK: func.return {b}
+ // CHECK: func.return {c}
+ %2 = func.call @bar(%cond) {tag = "c"} : (i1) -> i32
+ // CHECK: d:
+ // CHECK: op_preds: predecessors:
+ %3 = func.call @baz(%arg0) {tag = "d"} : (i32) -> i32
+ return %2 : i32
+}
+
+// A branching op the analysis cannot reason about (`test.unknown_br`,
+// presumably not a BranchOpInterface): both successors and both edges are
+// expected to be conservatively live.
+// CHECK: test_unknown_branch:
+// CHECK: region #0
+// CHECK: ^bb0 = live
+// CHECK: ^bb1 = live
+// CHECK: from ^bb0 = live
+// CHECK: ^bb2 = live
+// CHECK: from ^bb0 = live
+func.func @test_unknown_branch() -> ()
+ attributes {tag = "test_unknown_branch"} {
+ "test.unknown_br"() [^bb1, ^bb2] : () -> ()
+
+^bb1:
+ return
+
+^bb2:
+ return
+}
+
+// An op with regions whose control flow the analysis cannot reason about
+// (`test.unknown_region_br`, presumably not a RegionBranchOpInterface): the
+// entry blocks of both regions are expected to be conservatively live.
+// CHECK: test_unknown_region:
+// CHECK: region #0
+// CHECK: ^bb0 = live
+// CHECK: region #1
+// CHECK: ^bb0 = live
+func.func @test_unknown_region() -> () {
+ "test.unknown_region_br"() ({
+ ^bb0:
+ "test.unknown_region_end"() : () -> ()
+ }, {
+ ^bb0:
+ "test.unknown_region_end"() : () -> ()
+ }) {tag = "test_unknown_region"} : () -> ()
+ return
+}
+
+// A conditional branch on the constant `true`: only the true successor ^bb1
+// is expected to be live, and ^bb2 is expected to be dead.
+// CHECK: test_known_dead_block:
+// CHECK: region #0
+// CHECK: ^bb0 = live
+// CHECK: ^bb1 = live
+// CHECK: ^bb2 = dead
+func.func @test_known_dead_block() -> ()
+ attributes {tag = "test_known_dead_block"} {
+ %true = arith.constant true
+ cf.cond_br %true, ^bb1, ^bb2
+
+^bb1:
+ return
+
+^bb2:
+ return
+}
+
+// ^bb2 is reachable from ^bb0 on the non-constant %arg0. The branch in ^bb1
+// is on the constant `true` and always goes to ^bb3, so the ^bb1 -> ^bb2 edge
+// is expected to be dead while ^bb2 itself stays live.
+// CHECK: test_known_dead_edge:
+// CHECK: ^bb2 = live
+// CHECK: from ^bb1 = dead
+// CHECK: from ^bb0 = live
+func.func @test_known_dead_edge(%arg0: i1) -> ()
+ attributes {tag = "test_known_dead_edge"} {
+ cf.cond_br %arg0, ^bb1, ^bb2
+
+^bb1:
+ %true = arith.constant true
+ cf.cond_br %true, ^bb3, ^bb2
+
+^bb2:
+ return
+
+^bb3:
+ return
+}
+
+// scf.if on the constant `false`: only the else region is expected to be
+// entered, and only its yield is expected to flow back to the op's results.
+func.func @test_known_region_predecessors() -> () {
+ %false = arith.constant false
+ // CHECK: test_known_if:
+ // CHECK: region #0
+ // CHECK: ^bb0 = dead
+ // CHECK: region #1
+ // CHECK: ^bb0 = live
+ // CHECK: region_preds: (all) predecessors:
+ // CHECK: scf.if
+ // CHECK: op_preds: (all) predecessors:
+ // CHECK: scf.yield {else}
+ scf.if %false {
+ scf.yield {then}
+ } else {
+ scf.yield {else}
+ } {tag = "test_known_if"}
+ return
+}
+
+// @callable is called from both regions of an scf.if on the constant `true`.
+// Only the call in the then region is live, so it is expected to be the only
+// recorded callsite. The list is not reported as complete, presumably because
+// @callable is public.
+// CHECK: callable:
+// CHECK: region #0
+// CHECK: ^bb0 = live
+// CHECK: op_preds: predecessors:
+// CHECK: func.call @callable() {then}
+func.func @callable() attributes {tag = "callable"} {
+ return
+}
+
+func.func @test_dead_callsite() -> () {
+ %true = arith.constant true
+ scf.if %true {
+ func.call @callable() {then} : () -> ()
+ scf.yield
+ } else {
+ func.call @callable() {else} : () -> ()
+ scf.yield
+ }
+ return
+}
+
+// Both targets of the cond_br are ^bb1, so ^bb2 has no predecessor and
+// `return {false}` is never live. Only `return {true}` is expected among the
+// predecessors of the call in @test_call_dead_return.
+// NOTE(review): if the intent was to exercise constant folding of the branch,
+// the false target may have been meant to be ^bb2 -- confirm.
+func.func private @test_dead_return(%arg0: i32) -> i32 {
+ %true = arith.constant true
+ cf.cond_br %true, ^bb1, ^bb1
+
+^bb1:
+ return {true} %arg0 : i32
+
+^bb2:
+ return {false} %arg0 : i32
+}
+
+func.func @test_call_dead_return(%arg0: i32) -> () {
+ // CHECK: test_dead_return:
+ // CHECK: op_preds: (all) predecessors:
+ // CHECK: func.return {true}
+ %0 = func.call @test_dead_return(%arg0) {tag = "test_dead_return"} : (i32) -> i32
+ return
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index d0b9d2be4f6ea..128615659cf30 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_library(MLIRTestAnalysis
TestCallGraph.cpp
TestDataFlow.cpp
TestDataFlowFramework.cpp
+ TestDeadCodeAnalysis.cpp
TestLiveness.cpp
TestMatchReduction.cpp
TestMemRefBoundCheck.cpp
diff --git a/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp
new file mode 100644
index 0000000000000..a7c33526c80de
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp
@@ -0,0 +1,116 @@
+//===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SparseDataFlowAnalysis.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+/// Print the liveness of every block, control-flow edge, and the predecessors
+/// of all regions, callables, and calls.
+///
+/// Only operations nested in `op` (including `op`) that carry a `tag` string
+/// attribute are reported. A block or edge for which `solver` holds no
+/// `Executable` state is printed as "dead".
+static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
+                                 raw_ostream &os) {
+  // Print a liveness state, treating a missing state as dead.
+  auto printLiveness = [&](const Executable *state) {
+    if (state)
+      os << *state;
+    else
+      os << "dead";
+    os << "\n";
+  };
+
+  op->walk([&](Operation *taggedOp) {
+    auto tag = taggedOp->getAttrOfType<StringAttr>("tag");
+    if (!tag)
+      return;
+    os << tag.getValue() << ":\n";
+    for (Region &region : taggedOp->getRegions()) {
+      os << " region #" << region.getRegionNumber() << "\n";
+      for (Block &block : region) {
+        os << " ";
+        block.printAsOperand(os);
+        os << " = ";
+        printLiveness(solver.lookupState<Executable>(&block));
+        for (Block *pred : block.getPredecessors()) {
+          os << " from ";
+          pred->printAsOperand(os);
+          os << " = ";
+          printLiveness(solver.lookupState<Executable>(
+              solver.getProgramPoint<CFGEdge>(pred, &block)));
+        }
+      }
+      if (!region.empty()) {
+        auto *preds = solver.lookupState<PredecessorState>(&region.front());
+        if (preds)
+          os << "region_preds: " << *preds << "\n";
+      }
+    }
+    auto *preds = solver.lookupState<PredecessorState>(taggedOp);
+    if (preds)
+      os << "op_preds: " << *preds << "\n";
+  });
+}
+
+namespace {
+/// This is a simple analysis that implements a transfer function for constant
+/// operations: the result of every `ConstantLike` op whose value matches
+/// `m_Constant` is joined with that constant value. It supplies the constant
+/// operand values that `DeadCodeAnalysis` reads to prune branches.
+struct ConstantAnalysis : public DataFlowAnalysis {
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  /// Seed the analysis by visiting every constant-like op nested in `top`.
+  LogicalResult initialize(Operation *top) override {
+    WalkResult result = top->walk([&](Operation *op) {
+      if (op->hasTrait<OpTrait::ConstantLike>() && failed(visit(op)))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    return success(!result.wasInterrupted());
+  }
+
+  /// Join the matched constant value of the op at `point` into the lattice of
+  /// its first result. `point` is expected to be an operation.
+  LogicalResult visit(ProgramPoint point) override {
+    Operation *op = point.get<Operation *>();
+    Attribute value;
+    if (matchPattern(op, m_Constant(&value))) {
+      auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
+      propagateIfChanged(
+          constant, constant->join(ConstantValue(value, op->getDialect())));
+    }
+    return success();
+  }
+};
+
+/// This is a simple test pass that runs dead code analysis together with the
+/// constant analysis above, so that branches on known constants are pruned. It
+/// then prints the resulting block/edge liveness and predecessor states to
+/// stderr.
+struct TestDeadCodeAnalysisPass
+    : public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
+
+  StringRef getArgument() const override { return "test-dead-code-analysis"; }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<ConstantAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+    printAnalysisResults(solver, op, llvm::errs());
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+/// Register the `test-dead-code-analysis` pass with the global pass registry
+/// so that tools such as mlir-opt can run it by name.
+void registerTestDeadCodeAnalysisPass() {
+  PassRegistration<TestDeadCodeAnalysisPass> registration;
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 9e872ab63f5f4..ecd85a7f3aaa4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@ void registerTestGpuSerializeToCubinPass();
void registerTestGpuSerializeToHsacoPass();
void registerTestDataFlowPass();
void registerTestDataLayoutQuery();
+void registerTestDeadCodeAnalysisPass();
void registerTestDecomposeCallGraphTypes();
void registerTestDiagnosticsPass();
void registerTestDominancePass();
@@ -173,6 +174,7 @@ void registerTestPasses() {
mlir::test::registerTestDecomposeCallGraphTypes();
mlir::test::registerTestDataFlowPass();
mlir::test::registerTestDataLayoutQuery();
+ mlir::test::registerTestDeadCodeAnalysisPass();
mlir::test::registerTestDominancePass();
mlir::test::registerTestDynamicPipelinePass();
mlir::test::registerTestExpandMathPass();
More information about the llvm-branch-commits
mailing list