[llvm-branch-commits] [mlir] 08316f8 - [mlir] Add Dead Code Analysis

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jun 29 11:00:30 PDT 2022


Author: Mogball
Date: 2022-06-29T10:59:44-07:00
New Revision: 08316f82e5537e4ba8eff1d82ac4eb3b88e20307

URL: https://github.com/llvm/llvm-project/commit/08316f82e5537e4ba8eff1d82ac4eb3b88e20307
DIFF: https://github.com/llvm/llvm-project/commit/08316f82e5537e4ba8eff1d82ac4eb3b88e20307.diff

LOG: [mlir] Add Dead Code Analysis

This patch implements the analysis state classes needed for sparse data-flow analysis and implements a dead-code analysis using those states to determine liveness of blocks, control-flow edges, region predecessors, and function callsites.

Added: 
    mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
    mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
    mlir/test/Analysis/test-dead-code-analysis.mlir
    mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp

Modified: 
    mlir/include/mlir/Analysis/DataFlowAnalysis.h
    mlir/include/mlir/Analysis/DataFlowFramework.h
    mlir/lib/Analysis/CMakeLists.txt
    mlir/test/lib/Analysis/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlowAnalysis.h b/mlir/include/mlir/Analysis/DataFlowAnalysis.h
index 883624dfa4805..bc7664b6a20ee 100644
--- a/mlir/include/mlir/Analysis/DataFlowAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlowAnalysis.h
@@ -22,34 +22,17 @@
 #ifndef MLIR_ANALYSIS_DATAFLOWANALYSIS_H
 #define MLIR_ANALYSIS_DATAFLOWANALYSIS_H
 
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/Support/Allocator.h"
 
-namespace mlir {
-//===----------------------------------------------------------------------===//
-// ChangeResult
-//===----------------------------------------------------------------------===//
-
-/// A result type used to indicate if a change happened. Boolean operations on
-/// ChangeResult behave as though `Change` is truthy.
-enum class ChangeResult {
-  NoChange,
-  Change,
-};
-inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
-  return lhs == ChangeResult::Change ? lhs : rhs;
-}
-inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
-  lhs = lhs | rhs;
-  return lhs;
-}
-inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
-  return lhs == ChangeResult::NoChange ? lhs : rhs;
-}
+/// TODO: Remove this file when SCCP and integer range analysis have been ported
+/// to the new framework.
 
+namespace mlir {
 //===----------------------------------------------------------------------===//
 // AbstractLatticeElement
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index 9b85182532232..19d8fc0c3e19b 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -16,7 +16,6 @@
 #ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
 #define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H
 
-#include "mlir/Analysis/DataFlowAnalysis.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/StorageUniquer.h"
 #include "llvm/ADT/SetVector.h"
@@ -25,6 +24,27 @@
 
 namespace mlir {
 
+//===----------------------------------------------------------------------===//
+// ChangeResult
+//===----------------------------------------------------------------------===//
+
+/// Outcome of a state update: `Change` when the state was modified and
+/// `NoChange` otherwise. `Change` behaves as the truthy value: `|` yields
+/// `Change` if either operand is `Change`, while `&` yields `Change` only if
+/// both operands are `Change`.
+enum class ChangeResult {
+  NoChange,
+  Change,
+};
+inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
+  if (lhs == ChangeResult::Change)
+    return ChangeResult::Change;
+  return rhs;
+}
+inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
+  return lhs = lhs | rhs;
+}
+inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
+  if (lhs == ChangeResult::NoChange)
+    return ChangeResult::NoChange;
+  return rhs;
+}
+
 /// Forward declare the analysis state class.
 class AnalysisState;
 
@@ -137,6 +157,12 @@ struct ProgramPoint
   using ParentTy::PointerUnion;
   /// Allow implicit conversion from the parent type.
   ProgramPoint(ParentTy point = nullptr) : ParentTy(point) {}
+  /// Allow implicit conversion from operation wrappers, i.e. any type that
+  /// converts to `Operation *`. A plain `Operation *` is excluded so that it
+  /// keeps using the `ParentTy` constructor above.
+  /// TODO: Only needed on Windows. Find a better solution.
+  template <typename OpT,
+            typename = std::enable_if_t<
+                std::is_convertible<OpT, Operation *>::value &&
+                !std::is_same<OpT, Operation *>::value>>
+  ProgramPoint(OpT op) : ParentTy(op) {}
 
   /// Print the program point.
   void print(raw_ostream &os) const;
@@ -180,7 +206,7 @@ class DataFlowSolver {
   /// does not exist.
   template <typename StateT, typename PointT>
   const StateT *lookupState(PointT point) const {
-    auto it = analysisStates.find({point, TypeID::get<StateT>()});
+    auto it = analysisStates.find({ProgramPoint(point), TypeID::get<StateT>()});
     if (it == analysisStates.end())
       return nullptr;
     return static_cast<const StateT *>(it->second.get());

diff  --git a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
new file mode 100644
index 0000000000000..ee9392b387e60
--- /dev/null
+++ b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
@@ -0,0 +1,418 @@
+//===- SparseDataFlowAnalysis.h - Sparse data-flow analysis ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements sparse data-flow analysis using the data-flow analysis
+// framework. The analysis is forward and conditional and uses the results of
+// dead code analysis to prune dead code during the analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H
+#define MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H
+
+#include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseLattice
+//===----------------------------------------------------------------------===//
+
+/// Base class for sparse lattices. A sparse lattice is an analysis state
+/// attached to an SSA value; sparse data-flow analyses propagate these states
+/// across the IR.
+class AbstractSparseLattice : public AnalysisState {
+public:
+  /// Construct a lattice attached to `value`. Lattices only exist for values.
+  AbstractSparseLattice(Value value) : AnalysisState(value) {}
+
+  /// Merge the information held by `rhs` into this lattice. Returns whether
+  /// this lattice was modified.
+  virtual ChangeResult join(const AbstractSparseLattice &rhs) = 0;
+
+  /// Return true if this element is at a fixpoint, i.e. further calls to
+  /// `join` cannot change it.
+  virtual bool isAtFixpoint() const = 0;
+
+  /// Move this element to its pessimistic fixpoint: its value states may
+  /// conflict, so only the most conservative value may be relied on.
+  virtual ChangeResult markPessimisticFixpoint() = 0;
+
+  /// On change, enqueue every user of the attached value for each analysis
+  /// registered through `useDefSubscribe`.
+  void onUpdate(DataFlowSolver *solver) const override;
+
+  /// Register `analysis` to be re-invoked on all users of the value whenever
+  /// this lattice changes. This is cheaper than going through the dependency
+  /// map.
+  void useDefSubscribe(DataFlowAnalysis *analysis) {
+    useDefSubscribers.insert(analysis);
+  }
+
+private:
+  /// Analyses re-invoked on the value's users when this lattice changes.
+  SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
+            SmallPtrSet<DataFlowAnalysis *, 4>>
+      useDefSubscribers;
+};
+
+//===----------------------------------------------------------------------===//
+// Lattice
+//===----------------------------------------------------------------------===//
+
+/// A sparse lattice element holding a value of type `ValueT`. `ValueT` must
+/// provide:
+///
+///   * static ValueT join(const ValueT &lhs, const ValueT &rhs);
+///     - Conservatively combines `lhs` and `rhs` into a new value. It must be
+///       monotonic.
+///   * bool operator==(const ValueT &rhs) const;
+///   * void print(raw_ostream &os) const;
+///     - Used by `Lattice::print`.
+///
+template <typename ValueT>
+class Lattice : public AbstractSparseLattice {
+public:
+  using AbstractSparseLattice::AbstractSparseLattice;
+
+  /// Construct an element whose conservatively known value is `knownValue`.
+  Lattice(const ValueT &knownValue = ValueT())
+      : AbstractSparseLattice(Value()), knownValue(knownValue) {}
+
+  /// Return the current (optimistic) value. The element must be initialized.
+  ValueT &getValue() {
+    assert(!isUninitialized() && "expected known lattice element");
+    return *optimisticValue;
+  }
+  const ValueT &getValue() const {
+    return const_cast<Lattice<ValueT> *>(this)->getValue();
+  }
+
+  /// Return true if no value has been assigned to this element yet.
+  bool isUninitialized() const override { return !optimisticValue.hasValue(); }
+
+  /// Force initialization by moving the element to its pessimistic fixpoint.
+  ChangeResult defaultInitialize() override {
+    return markPessimisticFixpoint();
+  }
+
+  /// Return true when the optimistically assumed value equals the value known
+  /// to be true, so nothing further can be learned.
+  bool isAtFixpoint() const override { return optimisticValue == knownValue; }
+
+  /// Join the value of another `Lattice<ValueT>` into this element. Returns
+  /// whether this element changed.
+  ChangeResult join(const AbstractSparseLattice &rhs) override {
+    const auto &other = static_cast<const Lattice<ValueT> &>(rhs);
+
+    // Only join while below the fixpoint and when `other` holds a value.
+    if (!isAtFixpoint() && !other.isUninitialized())
+      return join(other.getValue());
+    return ChangeResult::NoChange;
+  }
+
+  /// Join `rhs` into this element. Returns whether this element changed.
+  ChangeResult join(const ValueT &rhs) {
+    // An uninitialized element simply adopts `rhs`.
+    if (isUninitialized()) {
+      optimisticValue = rhs;
+      return ChangeResult::Change;
+    }
+
+    ValueT joined = ValueT::join(*optimisticValue, rhs);
+    // Sanity-check that the user-provided `join` is monotonic.
+    assert(ValueT::join(joined, *optimisticValue) == joined &&
+           "expected `join` to be monotonic");
+    assert(ValueT::join(joined, rhs) == joined &&
+           "expected `join` to be monotonic");
+
+    if (joined == optimisticValue)
+      return ChangeResult::NoChange;
+    optimisticValue = joined;
+    return ChangeResult::Change;
+  }
+
+  /// Move the element to its pessimistic fixpoint by replacing the optimistic
+  /// value with the conservatively known one.
+  ChangeResult markPessimisticFixpoint() override {
+    if (isAtFixpoint())
+      return ChangeResult::NoChange;
+    optimisticValue = knownValue;
+    return ChangeResult::Change;
+  }
+
+  /// Print the element as `[known, optimistic]`, using `<NULL>` for an
+  /// uninitialized optimistic value.
+  void print(raw_ostream &os) const override {
+    os << "[";
+    knownValue.print(os);
+    os << ", ";
+    if (!optimisticValue)
+      os << "<NULL>";
+    else
+      optimisticValue->print(os);
+    os << "]";
+  }
+
+private:
+  /// The value conservatively known to be true.
+  ValueT knownValue;
+  /// The value optimistically assumed to be true, or None while the element
+  /// is uninitialized.
+  Optional<ValueT> optimisticValue;
+};
+
+//===----------------------------------------------------------------------===//
+// Executable
+//===----------------------------------------------------------------------===//
+
+/// Analysis state recording whether a program point (a block or a
+/// control-flow edge) is live. Points start out dead, and `setToLive` is the
+/// only way to change the state.
+class Executable : public AnalysisState {
+public:
+  using AnalysisState::AnalysisState;
+
+  /// Liveness always has a value, so this state is never uninitialized.
+  bool isUninitialized() const override { return false; }
+
+  /// Nothing to initialize; always reports `NoChange`.
+  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+
+  /// Mark the program point as live. Returns `Change` only if it was dead.
+  ChangeResult setToLive();
+
+  /// Return true if the program point is live.
+  bool isLive() const { return live; }
+
+  /// Print the liveness.
+  void print(raw_ostream &os) const override;
+
+  /// When the program point becomes live, re-invoke the subscribed analyses:
+  /// for a block, on the block itself and on each operation in it; for a CFG
+  /// edge, on the edge's destination block.
+  void onUpdate(DataFlowSolver *solver) const override;
+
+  /// Register `analysis` to be re-invoked when the liveness changes.
+  void blockContentSubscribe(DataFlowAnalysis *analysis) {
+    subscribers.insert(analysis);
+  }
+
+private:
+  /// Liveness of the program point; optimistically starts out dead.
+  bool live = false;
+
+  /// Analyses re-invoked when this state changes.
+  SetVector<DataFlowAnalysis *, SmallVector<DataFlowAnalysis *, 4>,
+            SmallPtrSet<DataFlowAnalysis *, 4>>
+      subscribers;
+};
+
+//===----------------------------------------------------------------------===//
+// ConstantValue
+//===----------------------------------------------------------------------===//
+
+/// Lattice value holding a known constant attribute, together with the
+/// dialect that can be used to materialize it.
+class ConstantValue {
+public:
+  /// Construct from a constant attribute and its materializing dialect. With
+  /// no arguments, no constant is known.
+  ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr)
+      : constant(knownValue), dialect(dialect) {}
+
+  /// Return the constant attribute, or null if no value was determined.
+  Attribute getConstantValue() const { return constant; }
+
+  /// Return the dialect instance that can materialize the constant.
+  Dialect *getConstantDialect() const { return dialect; }
+
+  /// Two values are equal when their constants are equal; the dialect is not
+  /// compared.
+  static_assert(true, "");
+  bool operator==(const ConstantValue &rhs) const {
+    return constant == rhs.constant;
+  }
+
+  /// Join two values. The result is `lhs` when the values are equal, and the
+  /// empty (null) constant when they are different.
+  static ConstantValue join(const ConstantValue &lhs,
+                            const ConstantValue &rhs) {
+    if (lhs == rhs)
+      return lhs;
+    return ConstantValue();
+  }
+
+  /// Print the constant value.
+  void print(raw_ostream &os) const;
+
+private:
+  /// The constant value; null if unknown.
+  Attribute constant;
+  /// A dialect instance that can be used to materialize the constant.
+  Dialect *dialect;
+};
+
+//===----------------------------------------------------------------------===//
+// PredecessorState
+//===----------------------------------------------------------------------===//
+
+/// Analysis state holding a set of known predecessor operations. Sparse
+/// data-flow analysis uses it to reason about region control-flow and the
+/// callgraph. It can also record that the set may be incomplete, for example
+/// when not all callsites of a callable are visible.
+class PredecessorState : public AnalysisState {
+public:
+  using AnalysisState::AnalysisState;
+
+  /// The state always has a value, so it is never uninitialized.
+  bool isUninitialized() const override { return false; }
+
+  /// Nothing to initialize; always reports `NoChange`.
+  ChangeResult defaultInitialize() override { return ChangeResult::NoChange; }
+
+  /// Print the known predecessors.
+  void print(raw_ostream &os) const override;
+
+  /// Return true if every predecessor is known.
+  bool allPredecessorsKnown() const { return allKnown; }
+
+  /// Record that there may be predecessors that are not known. Returns
+  /// `Change` only the first time.
+  ChangeResult setHasUnknownPredecessors() {
+    if (allKnown) {
+      allKnown = false;
+      return ChangeResult::Change;
+    }
+    return ChangeResult::NoChange;
+  }
+
+  /// Return the known predecessors, in insertion order.
+  ArrayRef<Operation *> getKnownPredecessors() const {
+    return knownPredecessors.getArrayRef();
+  }
+
+  /// Add `predecessor` to the known set. Returns `Change` if it was new.
+  ChangeResult join(Operation *predecessor) {
+    bool inserted = knownPredecessors.insert(predecessor);
+    return inserted ? ChangeResult::Change : ChangeResult::NoChange;
+  }
+
+private:
+  /// Whether every predecessor is known; optimistically assumed at first.
+  bool allKnown = true;
+
+  /// The known control-flow predecessors of this program point.
+  SetVector<Operation *, SmallVector<Operation *, 4>,
+            SmallPtrSet<Operation *, 4>>
+      knownPredecessors;
+};
+
+//===----------------------------------------------------------------------===//
+// CFGEdge
+//===----------------------------------------------------------------------===//
+
+/// Program point for the control-flow edge from a block to one of its
+/// successor blocks.
+class CFGEdge
+    : public GenericProgramPointBase<CFGEdge, std::pair<Block *, Block *>> {
+public:
+  using Base::Base;
+
+  /// Return the source block of the edge.
+  Block *getFrom() const { return getValue().first; }
+  /// Return the destination block of the edge.
+  Block *getTo() const { return getValue().second; }
+
+  /// Print the source block, an arrow, then the destination block.
+  void print(raw_ostream &os) const override;
+  /// Return a fused location built from both endpoints of the edge.
+  Location getLoc() const override;
+};
+
+//===----------------------------------------------------------------------===//
+// DeadCodeAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Dead code analysis reasons about control flow as described by
+/// `RegionBranchOpInterface` and `BranchOpInterface`, and about the callgraph
+/// as described by `CallableOpInterface` and `CallOpInterface`.
+///
+/// Using known constant operand values, it determines the liveness of each
+/// block and of each edge between a block and its predecessors. For region
+/// control-flow it computes the predecessor operations of region entry blocks
+/// and of region control-flow operations; for the callgraph it computes the
+/// callsites and live returns of every function.
+class DeadCodeAnalysis : public DataFlowAnalysis {
+public:
+  explicit DeadCodeAnalysis(DataFlowSolver &solver);
+
+  /// Seed the analysis on `top`: mark its entry blocks live, handle symbol
+  /// callables with potentially unknown callsites, and visit every operation
+  /// with potential control-flow semantics.
+  LogicalResult initialize(Operation *top) override;
+
+  /// Visit an operation with control-flow semantics and determine which of
+  /// its successors are live.
+  LogicalResult visit(ProgramPoint point) override;
+
+private:
+  /// Mark symbol callables whose callsites may not all be known as having
+  /// unknown predecessors. `top` is the operation being analyzed.
+  void initializeSymbolCallables(Operation *top);
+
+  /// Recursively initialize the analysis on `op` and its nested operations.
+  LogicalResult initializeRecursively(Operation *op);
+
+  /// Visit a live call operation and update the predecessor state of its
+  /// callee.
+  void visitCallOperation(CallOpInterface call);
+
+  /// Visit a branch operation with successors and determine which of them are
+  /// live from the current block.
+  void visitBranchOperation(BranchOpInterface branch);
+
+  /// Visit an operation that defines regions with region control-flow, and
+  /// compute the state of both its results and its nested regions.
+  void visitRegionBranchOperation(RegionBranchOpInterface branch);
+
+  /// Visit a terminator without CFG successors that exits a region of an
+  /// operation with region control-flow semantics.
+  void visitRegionTerminator(Operation *op, RegionBranchOpInterface branch);
+
+  /// Visit a terminator without CFG successors that exits a callable region.
+  void visitCallableTerminator(Operation *op, CallableOpInterface callable);
+
+  /// Mark the block `to` and the edge from `from` to `to` as live.
+  void markEdgeLive(Block *from, Block *to);
+
+  /// Mark the entry block of every non-empty region of `op` as live.
+  void markEntryBlocksLive(Operation *op);
+
+  /// Return the constant values of `op`'s operands, or None if any operand
+  /// lattice is uninitialized.
+  Optional<SmallVector<Attribute>> getOperandValues(Operation *op);
+
+  /// Symbol table cache used for symbol lookups (e.g. resolving callees).
+  SymbolTableCollection symbolTable;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H

diff  --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index fe5f3832322ce..662bc1084c5ff 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -7,6 +7,7 @@ set(LLVM_OPTIONAL_SOURCES
   IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
+  SparseDataFlowAnalysis.cpp
 
   AliasAnalysis/LocalAliasAnalysis.cpp
   )
@@ -21,6 +22,7 @@ add_mlir_library(MLIRAnalysis
   IntRangeAnalysis.cpp
   Liveness.cpp
   SliceAnalysis.cpp
+  SparseDataFlowAnalysis.cpp
 
   AliasAnalysis/LocalAliasAnalysis.cpp
 

diff  --git a/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
new file mode 100644
index 0000000000000..41ff8381b0749
--- /dev/null
+++ b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
@@ -0,0 +1,413 @@
+//===- SparseDataFlowAnalysis.cpp - Sparse data-flow analysis -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SparseDataFlowAnalysis.h"
+
+#define DEBUG_TYPE "dataflow"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseLattice
+//===----------------------------------------------------------------------===//
+
+/// Re-invoke every use-def subscriber on each user of the attached value.
+void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
+  Value value = point.get<Value>();
+  for (Operation *user : value.getUsers()) {
+    for (DataFlowAnalysis *analysis : useDefSubscribers)
+      solver->enqueue({user, analysis});
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Executable
+//===----------------------------------------------------------------------===//
+
+/// Flip the state from dead to live. Once live, the state never changes again.
+ChangeResult Executable::setToLive() {
+  if (!live) {
+    live = true;
+    return ChangeResult::Change;
+  }
+  return ChangeResult::NoChange;
+}
+
+/// Print `live` or `dead`.
+void Executable::print(raw_ostream &os) const {
+  if (live)
+    os << "live";
+  else
+    os << "dead";
+}
+
+void Executable::onUpdate(DataFlowSolver *solver) const {
+  // A block changed: revisit the block, then every operation inside it.
+  if (Block *block = point.dyn_cast<Block *>()) {
+    for (DataFlowAnalysis *analysis : subscribers)
+      solver->enqueue({block, analysis});
+    for (DataFlowAnalysis *analysis : subscribers)
+      for (Operation &op : *block)
+        solver->enqueue({&op, analysis});
+    return;
+  }
+
+  // A CFG edge changed: revisit its destination block.
+  auto *genericPoint = point.dyn_cast<GenericProgramPoint *>();
+  if (!genericPoint)
+    return;
+  if (auto *edge = dyn_cast<CFGEdge>(genericPoint)) {
+    Block *target = edge->getTo();
+    for (DataFlowAnalysis *analysis : subscribers)
+      solver->enqueue({target, analysis});
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// ConstantValue
+//===----------------------------------------------------------------------===//
+
+/// Print the constant attribute, or `<NO VALUE>` when none is known.
+void ConstantValue::print(raw_ostream &os) const {
+  if (!constant) {
+    os << "<NO VALUE>";
+    return;
+  }
+  constant.print(os);
+}
+
+//===----------------------------------------------------------------------===//
+// PredecessorState
+//===----------------------------------------------------------------------===//
+
+/// Print the known predecessors, one per line, prefixed by `(all) ` when the
+/// set is known to be complete.
+void PredecessorState::print(raw_ostream &os) const {
+  os << (allPredecessorsKnown() ? "(all) " : "") << "predecessors:\n";
+  for (Operation *predecessor : getKnownPredecessors())
+    os << "  " << *predecessor << "\n";
+}
+
+//===----------------------------------------------------------------------===//
+// CFGEdge
+//===----------------------------------------------------------------------===//
+
+/// Fuse the locations of the regions that contain the two endpoint blocks.
+Location CFGEdge::getLoc() const {
+  Region *fromRegion = getFrom()->getParent();
+  Region *toRegion = getTo()->getParent();
+  return FusedLoc::get(fromRegion->getContext(),
+                       {fromRegion->getLoc(), toRegion->getLoc()});
+}
+
+/// Print the source block, an arrow, then the destination block.
+void CFGEdge::print(raw_ostream &os) const {
+  Block *from = getFrom();
+  Block *to = getTo();
+  from->print(os);
+  os << "\n -> \n";
+  to->print(os);
+}
+
+//===----------------------------------------------------------------------===//
+// DeadCodeAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Register `CFGEdge` with the solver, since this analysis creates states
+/// attached to CFG edges.
+DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver)
+    : DataFlowAnalysis(solver) {
+  registerPointKind<CFGEdge>();
+}
+
+/// Seed the analysis on `top`: its region entry blocks are live, symbol
+/// callables with potentially unknown callsites are marked as such, and every
+/// nested operation with control-flow semantics is visited.
+LogicalResult DeadCodeAnalysis::initialize(Operation *top) {
+  // Mark the top-level entry blocks as executable. This reuses the shared
+  // helper instead of duplicating its loop.
+  markEntryBlocksLive(top);
+
+  // Mark as overdefined the predecessors of symbol callables with potentially
+  // unknown predecessors.
+  initializeSymbolCallables(top);
+
+  return initializeRecursively(top);
+}
+
+/// For every symbol table under `top`, mark callables whose callsites may not
+/// all be visible as having unknown predecessors. That covers public symbols,
+/// nested symbols when not all uses are visible, and symbols with a non-call
+/// use.
+void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
+  auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
+    Region &symbolTableRegion = symTable->getRegion(0);
+    Block *symbolTableBlock = &symbolTableRegion.front();
+
+    bool foundSymbolCallable = false;
+    for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
+      Region *callableRegion = callable.getCallableRegion();
+      if (!callableRegion)
+        continue;
+      auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
+      if (!symbol)
+        continue;
+
+      // Public symbol callables or those for which we can't see all uses have
+      // potentially unknown callsites.
+      if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
+        auto *state = getOrCreate<PredecessorState>(callable);
+        propagateIfChanged(state, state->setHasUnknownPredecessors());
+      }
+      foundSymbolCallable = true;
+    }
+
+    // Exit early if no eligible symbol callables were found in the table.
+    if (!foundSymbolCallable)
+      return;
+
+    // Walk the symbol table to check for non-call uses of symbols.
+    Optional<SymbolTable::UseRange> uses =
+        SymbolTable::getSymbolUses(&symbolTableRegion);
+    if (!uses) {
+      // If we couldn't gather the symbol uses, conservatively assume that
+      // we can't track information for any nested symbols.
+      return top->walk([&](CallableOpInterface callable) {
+        auto *state = getOrCreate<PredecessorState>(callable);
+        propagateIfChanged(state, state->setHasUnknownPredecessors());
+      });
+    }
+
+    for (const SymbolTable::SymbolUse &use : *uses) {
+      if (isa<CallOpInterface>(use.getUser()))
+        continue;
+      // If a callable symbol has a non-call use, then we can't be guaranteed to
+      // know all callsites.
+      Operation *symbol = symbolTable.lookupSymbolIn(top, use.getSymbolRef());
+      // The reference may fail to resolve (e.g. it names a symbol that does
+      // not exist). Skip it rather than creating state for a null point.
+      if (!symbol)
+        continue;
+      auto *state = getOrCreate<PredecessorState>(symbol);
+      propagateIfChanged(state, state->setHasUnknownPredecessors());
+    }
+  };
+  SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(),
+                                walkFn);
+}
+
+/// Visit `op` if it may have control-flow semantics (regions, successors,
+/// terminator, or call), subscribing it to its parent block's liveness, then
+/// recurse into all nested operations. Stops at the first failure.
+LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) {
+  if (op->getNumRegions() || op->getNumSuccessors() ||
+      op->hasTrait<OpTrait::IsTerminator>() || isa<CallOpInterface>(op)) {
+    // When the liveness of the parent block changes, make sure to re-invoke the
+    // analysis on the op.
+    if (op->getBlock())
+      getOrCreate<Executable>(op->getBlock())->blockContentSubscribe(this);
+    if (failed(visit(op)))
+      return failure();
+  }
+  // Recurse on nested operations. The loop variable is named distinctly so it
+  // does not shadow the `op` parameter.
+  for (Region &region : op->getRegions())
+    for (Operation &nestedOp : region.getOps())
+      if (failed(initializeRecursively(&nestedOp)))
+        return failure();
+  return success();
+}
+
+/// Mark the destination block `to` as live, then mark the edge from `from` to
+/// `to` as live.
+void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) {
+  auto *targetState = getOrCreate<Executable>(to);
+  propagateIfChanged(targetState, targetState->setToLive());
+
+  auto *edge = getProgramPoint<CFGEdge>(from, to);
+  auto *edgeState = getOrCreate<Executable>(edge);
+  propagateIfChanged(edgeState, edgeState->setToLive());
+}
+
+/// Mark the first (entry) block of every non-empty region of `op` as live.
+void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) {
+  for (Region &region : op->getRegions()) {
+    if (!region.empty()) {
+      auto *entryState = getOrCreate<Executable>(&region.front());
+      propagateIfChanged(entryState, entryState->setToLive());
+    }
+  }
+}
+
+LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) {
+  // Blocks need no work of their own here.
+  if (point.is<Block *>())
+    return success();
+  Operation *op = point.dyn_cast<Operation *>();
+  if (!op)
+    return emitError(point.getLoc(), "unknown program point kind");
+
+  // Nothing to do while the parent block is not known to be live.
+  if (!getOrCreate<Executable>(op->getBlock())->isLive())
+    return success();
+
+  // A live call is a known callsite of its callee.
+  if (auto call = dyn_cast<CallOpInterface>(op))
+    visitCallOperation(call);
+
+  // Decide which nested regions are entered.
+  if (op->getNumRegions()) {
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+      // Region control-flow is described by the interface.
+      visitRegionBranchOperation(branch);
+    } else if (auto callable = dyn_cast<CallableOpInterface>(op)) {
+      // Enter a callable if some callsites are unknown or at least one
+      // callsite is known.
+      const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
+      bool mayBeCalled = !callsites->allPredecessorsKnown() ||
+                         !callsites->getKnownPredecessors().empty();
+      if (mayBeCalled)
+        markEntryBlocksLive(callable);
+    } else {
+      // Unknown region semantics: conservatively enter every region.
+      markEntryBlocksLive(op);
+    }
+  }
+
+  // A terminator without successors exits its parent's region.
+  if (op->hasTrait<OpTrait::IsTerminator>() && !op->getNumSuccessors()) {
+    Operation *parent = op->getParentOp();
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(parent))
+      visitRegionTerminator(op, branch);
+    else if (auto callable = dyn_cast<CallableOpInterface>(parent))
+      visitCallableTerminator(op, callable);
+  }
+
+  // Decide which CFG successors are reachable.
+  if (op->getNumSuccessors()) {
+    if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+      visitBranchOperation(branch);
+    } else {
+      // Unknown branch semantics: conservatively mark every successor edge as
+      // executable.
+      Block *block = op->getBlock();
+      for (Block *successor : op->getSuccessors())
+        markEdgeLive(block, successor);
+    }
+  }
+
+  return success();
+}
+
+/// Record a live call operation.
+///
+/// The callee is found either from the SSA value the call targets (through
+/// its defining op) or via `resolveCallable` with `symbolTable`. If it is a
+/// symbol operation that is not a body-less callable, `call` is joined into
+/// the callee's callsite set. Otherwise the call's own predecessor state is
+/// marked as having unknown predecessors.
+void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
+  Value calleeValue = call.getCallableForCallee().dyn_cast<Value>();
+  Operation *callableOp = calleeValue ? calleeValue.getDefiningOp()
+                                      : call.resolveCallable(&symbolTable);
+
+  // A callable without a callable region is only an external declaration, so
+  // its return points are not visible to this analysis.
+  bool isTrackableCallee = false;
+  if (isa_and_nonnull<SymbolOpInterface>(callableOp)) {
+    auto callable = dyn_cast<CallableOpInterface>(callableOp);
+    isTrackableCallee = !callable || callable.getCallableRegion() != nullptr;
+  }
+
+  // TODO: Add support for non-symbol callables when necessary. If the
+  // callable has non-call uses we would mark as having reached pessimistic
+  // fixpoint, otherwise allow for propagating the return values out.
+  if (isTrackableCallee) {
+    auto *callsites = getOrCreate<PredecessorState>(callableOp);
+    propagateIfChanged(callsites, callsites->join(call));
+    return;
+  }
+
+  // Unknown or external callee: the places control returns from are unknown.
+  auto *predecessors = getOrCreate<PredecessorState>(call);
+  propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors());
+}
+
+/// Collect, for each operand of `op`, the constant attribute held by the
+/// lattice that `getLattice` returns for it. As soon as one operand's lattice
+/// is still uninitialized, None is returned (later operands are not queried),
+/// signalling that the caller should wait for more information.
+static Optional<SmallVector<Attribute>> getOperandValuesImpl(
+    Operation *op,
+    function_ref<const Lattice<ConstantValue> *(Value)> getLattice) {
+  SmallVector<Attribute> values;
+  values.reserve(op->getNumOperands());
+  for (Value operand : op->getOperands()) {
+    const Lattice<ConstantValue> *lattice = getLattice(operand);
+    if (lattice->isUninitialized())
+      return llvm::None;
+    values.push_back(lattice->getValue().getConstantValue());
+  }
+  return values;
+}
+
+/// Query the constant operand values of `op`. Each operand's constant lattice
+/// is created on demand and `useDefSubscribe`d by this analysis (presumably so
+/// users of the value are revisited when it changes).
+Optional<SmallVector<Attribute>>
+DeadCodeAnalysis::getOperandValues(Operation *op) {
+  auto subscribeAndGet = [&](Value value) -> const Lattice<ConstantValue> * {
+    auto *lattice = getOrCreate<Lattice<ConstantValue>>(value);
+    lattice->useDefSubscribe(this);
+    return lattice;
+  };
+  return getOperandValuesImpl(op, subscribeAndGet);
+}
+
+/// Mark the CFG successors of `branch` live. When the constant operand values
+/// let the branch name the single successor it takes, only that edge is
+/// marked. Otherwise every successor edge is marked. Nothing happens while any
+/// operand value is still uninitialized.
+void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) {
+  Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
+  if (!operands)
+    return;
+
+  Block *from = branch->getBlock();
+  if (Block *taken = branch.getSuccessorForOperands(*operands))
+    return markEdgeLive(from, taken);
+
+  for (Block *successor : branch->getSuccessors())
+    markEdgeLive(from, successor);
+}
+
+/// Handle control entering the regions of `branch` from the op itself.
+///
+/// Once every operand value is initialized, the op is asked which regions it
+/// may enter given those (possibly constant) operands. Each such region's
+/// entry block is marked live and records `branch` as a predecessor.
+///
+/// A successor without a region denotes the parent operation itself (the same
+/// convention `visitRegionTerminator` handles), i.e. control may bypass every
+/// region. That path is not modelled precisely, so the op's predecessors are
+/// conservatively marked unknown instead of dereferencing a null region.
+void DeadCodeAnalysis::visitRegionBranchOperation(
+    RegionBranchOpInterface branch) {
+  // Try to deduce which regions are executable.
+  Optional<SmallVector<Attribute>> operands = getOperandValues(branch);
+  if (!operands)
+    return;
+
+  SmallVector<RegionSuccessor> successors;
+  branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
+
+  for (const RegionSuccessor &successor : successors) {
+    Region *region = successor.getSuccessor();
+    if (!region) {
+      // Control flows from the op straight back to the op. NOTE(review): e.g.
+      // an op whose only viable region is absent for the given operands --
+      // confirm which ops produce this from their entry.
+      auto *opPredecessors = getOrCreate<PredecessorState>(branch);
+      propagateIfChanged(opPredecessors,
+                         opPredecessors->setHasUnknownPredecessors());
+      continue;
+    }
+    // Mark the entry block as executable.
+    auto *state = getOrCreate<Executable>(&region->front());
+    propagateIfChanged(state, state->setToLive());
+    // Add the parent op as a predecessor.
+    auto *predecessors = getOrCreate<PredecessorState>(&region->front());
+    propagateIfChanged(predecessors, predecessors->join(branch));
+  }
+}
+
+/// Handle a terminator `op` that exits a region of `branch`.
+///
+/// For every region `branch` may transfer control to next, that region's
+/// entry block is marked live and `op` becomes a predecessor of it. A
+/// transfer back to the parent operation makes `op` a predecessor of `branch`
+/// itself. Nothing happens while any operand value is still uninitialized.
+void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
+                                             RegionBranchOpInterface branch) {
+  Optional<SmallVector<Attribute>> operands = getOperandValues(op);
+  if (!operands)
+    return;
+
+  unsigned regionIndex = op->getParentRegion()->getRegionNumber();
+  SmallVector<RegionSuccessor> successors;
+  branch.getSuccessorRegions(regionIndex, *operands, successors);
+
+  for (const RegionSuccessor &successor : successors) {
+    Region *region = successor.getSuccessor();
+    if (!region) {
+      // Control returns to the parent operation.
+      PredecessorState *parentPreds = getOrCreate<PredecessorState>(branch);
+      propagateIfChanged(parentPreds, parentPreds->join(op));
+      continue;
+    }
+    Block *entry = &region->front();
+    Executable *entryState = getOrCreate<Executable>(entry);
+    propagateIfChanged(entryState, entryState->setToLive());
+    PredecessorState *entryPreds = getOrCreate<PredecessorState>(entry);
+    propagateIfChanged(entryPreds, entryPreds->join(op));
+  }
+}
+
+/// Handle a terminator `op` that exits the body of `callable`.
+///
+/// `op` is recorded as a predecessor of every known live callsite of
+/// `callable`, i.e. as a point from which control returns to that call. A
+/// terminator that is not `ReturnLike` is not understood to hand control back
+/// to the caller, so those callsites are instead marked as having unknown
+/// predecessors.
+///
+/// This is done even when `op` returns no values. A call's predecessors
+/// describe where control comes back from, not only where results come from.
+/// This also matches `visitRegionTerminator`, which records operand-less
+/// region terminators as predecessors too.
+void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
+                                               CallableOpInterface callable) {
+  // `getOrCreateFor` registers `op` as a dependent of the callsite state,
+  // presumably so this terminator is revisited as new callsites appear.
+  const auto *callsites = getOrCreateFor<PredecessorState>(op, callable);
+  bool canResolve = op->hasTrait<OpTrait::ReturnLike>();
+  for (Operation *predecessor : callsites->getKnownPredecessors()) {
+    assert(isa<CallOpInterface>(predecessor) && "expected a call operation");
+    auto *predecessors = getOrCreate<PredecessorState>(predecessor);
+    if (canResolve) {
+      propagateIfChanged(predecessors, predecessors->join(op));
+    } else {
+      // If the terminator is not return-like, conservatively assume we can't
+      // resolve the predecessors of the call.
+      propagateIfChanged(predecessors,
+                         predecessors->setHasUnknownPredecessors());
+    }
+  }
+}

diff  --git a/mlir/test/Analysis/test-dead-code-analysis.mlir b/mlir/test/Analysis/test-dead-code-analysis.mlir
new file mode 100644
index 0000000000000..2668b8349d1ca
--- /dev/null
+++ b/mlir/test/Analysis/test-dead-code-analysis.mlir
@@ -0,0 +1,248 @@
+// RUN: mlir-opt -test-dead-code-analysis 2>&1 %s | FileCheck %s
+
+// For every op carrying a `tag` attribute, the test pass prints each block's
+// liveness, the liveness of each incoming CFG edge, and any recorded region or
+// op predecessors; the directives in this file match that output in order.
+
+// CHECK: test_cfg:
+// CHECK:  region #0
+// CHECK:   ^bb0 = live
+// CHECK:   ^bb1 = live
+// CHECK:    from ^bb1 = live
+// CHECK:    from ^bb0 = live
+// CHECK:   ^bb2 = live
+// CHECK:    from ^bb1 = live
+// The branch condition is a function argument (not a known constant), so both
+// successors of ^bb1, including its self-loop edge, are expected to be live.
+func.func @test_cfg(%cond: i1) -> () 
+    attributes {tag = "test_cfg"} {
+  cf.br ^bb1
+
+^bb1:
+  cf.cond_br %cond, ^bb1, ^bb2
+
+^bb2:
+  return
+}
+
+// Region control flow driven by non-constant conditions: every region of the
+// `scf.if` and `scf.while` below is expected to be reachable.
+func.func @test_region_control_flow(%cond: i1, %arg0: i64, %arg1: i64) -> () {
+  // CHECK: test_if:
+  // CHECK:  region #0
+  // CHECK: region_preds: (all) predecessors:
+  // CHECK:   scf.if 
+  // CHECK:  region #1
+  // CHECK: region_preds: (all) predecessors:
+  // CHECK:   scf.if 
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   scf.yield {then}
+  // CHECK:   scf.yield {else}
+  scf.if %cond {
+    scf.yield {then}
+  } else {
+    scf.yield {else}
+  } {tag = "test_if"} 
+
+  // NOTE(review): the expected output for `test_while` below is written as
+  // plain comments without the FileCheck directive prefix, so it is not
+  // verified by this test. Confirm whether that is intentional.
+  // test_while:
+  //  region #0
+  // region_preds: (all) predecessors:
+  //   scf.while 
+  //   scf.yield 
+  //  region #1
+  // region_preds: (all) predecessors:
+  //   scf.condition
+  // op_preds: (all) predecessors:
+  //   scf.condition
+  %c2_i64 = arith.constant 2 : i64
+  %0:2 = scf.while (%arg2 = %arg0) : (i64) -> (i64, i64) {
+    %1 = arith.cmpi slt, %arg2, %arg1 : i64
+    scf.condition(%1) %arg2, %arg2 : i64, i64
+  } do {
+  ^bb0(%arg2: i64, %arg3: i64):
+    %1 = arith.muli %arg3, %c2_i64 : i64
+    scf.yield %1 : i64
+  } attributes {tag = "test_while"}
+
+  return
+}
+
+// Private callee with two live callsites ("a" and "b"); its set of callsites
+// is expected to be complete ("(all)").
+// CHECK: foo:
+// CHECK:  region #0
+// CHECK:   ^bb0 = live
+// CHECK: op_preds: (all) predecessors:
+// CHECK:   func.call @foo(%{{.*}}) {tag = "a"}
+// CHECK:   func.call @foo(%{{.*}}) {tag = "b"}
+func.func private @foo(%arg0: i32) -> i32 
+    attributes {tag = "foo"} {
+  return {a} %arg0 : i32 
+}
+
+// The callsite set of @bar is expected NOT to be marked complete (no
+// "(all)"), presumably because @bar is public and may be called from outside
+// this module. Both returns are reachable since %cond is not a constant.
+// CHECK: bar:
+// CHECK:  region #0
+// CHECK:   ^bb0 = live
+// CHECK: op_preds: predecessors:
+// CHECK:   func.call @bar(%{{.*}}) {tag = "c"}
+func.func @bar(%cond: i1) -> i32 
+    attributes {tag = "bar"} {
+  cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  %c0 = arith.constant 0 : i32
+  return {b} %c0 : i32
+
+^bb2:
+  %c1 = arith.constant 1 : i32
+  return {c} %c1 : i32
+}
+
+// Body-less declaration: calls to it are not added to its callsite set, so
+// the expected set is complete and empty.
+// CHECK: baz
+// CHECK: op_preds: (all) predecessors:
+func.func private @baz(i32) -> i32 attributes {tag = "baz"}
+
+func.func @test_callgraph(%cond: i1, %arg0: i32) -> i32 {
+  // CHECK: a:
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   func.return {a}
+  %0 = func.call @foo(%arg0) {tag = "a"} : (i32) -> i32
+  cf.cond_br %cond, ^bb1, ^bb2
+
+^bb1:
+  // CHECK: b:
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   func.return {a}
+  %1 = func.call @foo(%arg0) {tag = "b"} : (i32) -> i32
+  return %1 : i32
+
+^bb2:
+  // CHECK: c:
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   func.return {b}
+  // CHECK:   func.return {c}
+  %2 = func.call @bar(%cond) {tag = "c"} : (i1) -> i32
+  // Call "d" targets the body-less @baz, so the places control returns from
+  // are unknown (no "(all)").
+  // CHECK: d:
+  // CHECK: op_preds: predecessors:
+  %3 = func.call @baz(%arg0) {tag = "d"} : (i32) -> i32
+  return %2 : i32
+}
+
+// `test.unknown_br` presumably does not implement BranchOpInterface, so the
+// analysis is expected to conservatively mark every successor edge live.
+// CHECK: test_unknown_branch:
+// CHECK:  region #0
+// CHECK:   ^bb0 = live
+// CHECK:   ^bb1 = live
+// CHECK:    from ^bb0 = live
+// CHECK:   ^bb2 = live
+// CHECK:    from ^bb0 = live
+func.func @test_unknown_branch() -> () 
+    attributes {tag = "test_unknown_branch"} {
+  "test.unknown_br"() [^bb1, ^bb2] : () -> ()
+
+^bb1:
+  return
+
+^bb2:
+  return
+}
+
+// `test.unknown_region_br` presumably implements neither region-branch nor
+// callable interfaces, so the entry block of each region is expected to be
+// conservatively marked live.
+// CHECK: test_unknown_region:
+// CHECK:  region #0
+// CHECK:   ^bb0 = live
+// CHECK:  region #1
+// CHECK:   ^bb0 = live
+func.func @test_unknown_region() -> () {
+  "test.unknown_region_br"() ({
+  ^bb0:
+    "test.unknown_region_end"() : () -> ()
+  }, {    
+  ^bb0:
+    "test.unknown_region_end"() : () -> ()
+  }) {tag = "test_unknown_region"} : () -> ()
+  return
+}
+
+// The branch condition is the constant `true`, so only the edge to ^bb1 is
+// taken and ^bb2 is expected to stay dead.
+// CHECK: test_known_dead_block:
+// CHECK:  region #0
+// CHECK:   ^bb0 = live
+// CHECK:   ^bb1 = live
+// CHECK:   ^bb2 = dead
+func.func @test_known_dead_block() -> () 
+    attributes {tag = "test_known_dead_block"} {
+  %true = arith.constant true
+  cf.cond_br %true, ^bb1, ^bb2
+
+^bb1:
+  return
+
+^bb2:
+  return
+}
+
+// ^bb2 is reached from ^bb0 (non-constant condition), but ^bb1 branches on a
+// constant `true`, so its edge to ^bb2 is expected to be dead.
+// CHECK: test_known_dead_edge:
+// CHECK:   ^bb2 = live
+// CHECK:    from ^bb1 = dead
+// CHECK:    from ^bb0 = live
+func.func @test_known_dead_edge(%arg0: i1) -> ()
+    attributes {tag = "test_known_dead_edge"} {
+  cf.cond_br %arg0, ^bb1, ^bb2
+
+^bb1:
+  %true = arith.constant true 
+  cf.cond_br %true, ^bb3, ^bb2
+
+^bb2:
+  return
+
+^bb3:
+  return
+}
+
+func.func @test_known_region_predecessors() -> () {
+  %false = arith.constant false
+  // With a constant-false condition only the else region is reachable, so
+  // only its yield is expected to flow back to the `scf.if`.
+  // CHECK: test_known_if:
+  // CHECK:  region #0
+  // CHECK:   ^bb0 = dead
+  // CHECK:  region #1
+  // CHECK:   ^bb0 = live
+  // CHECK: region_preds: (all) predecessors:
+  // CHECK:   scf.if 
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   scf.yield {else}
+  scf.if %false {
+    scf.yield {then}
+  } else {
+    scf.yield {else}
+  } {tag = "test_known_if"} 
+  return
+}
+
+// Only the call in the taken (`true`) region of @test_dead_callsite should be
+// recorded as a callsite; the other call sits in a dead region. The set is not
+// expected to be complete, presumably because @callable is public.
+// CHECK: callable:
+// CHECK:  region #0
+// CHECK:   ^bb0 = live
+// CHECK: op_preds: predecessors:
+// CHECK:   func.call @callable() {then}
+func.func @callable() attributes {tag = "callable"} {
+  return
+}
+
+func.func @test_dead_callsite() -> () {
+  %true = arith.constant true
+  scf.if %true {
+    func.call @callable() {then} : () -> ()
+    scf.yield 
+  } else {
+    func.call @callable() {else} : () -> ()
+    scf.yield 
+  }
+  return
+}
+
+// Only `return {true}` is expected to be a live return point of this callee.
+// NOTE(review): both successors of the branch below are ^bb1, so ^bb2 has no
+// predecessors at all and is dead regardless of the constant condition. If the
+// intent is to exercise constant folding, the false target was probably meant
+// to be ^bb2 -- confirm.
+func.func private @test_dead_return(%arg0: i32) -> i32 {
+  %true = arith.constant true
+  cf.cond_br %true, ^bb1, ^bb1
+
+^bb1:
+  return {true} %arg0 : i32
+
+^bb2:
+  return {false} %arg0 : i32
+}
+
+func.func @test_call_dead_return(%arg0: i32) -> () {
+  // CHECK: test_dead_return:
+  // CHECK: op_preds: (all) predecessors:
+  // CHECK:   func.return {true}
+  %0 = func.call @test_dead_return(%arg0) {tag = "test_dead_return"} : (i32) -> i32
+  return
+}

diff  --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
index d0b9d2be4f6ea..128615659cf30 100644
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_library(MLIRTestAnalysis
   TestCallGraph.cpp
   TestDataFlow.cpp
   TestDataFlowFramework.cpp
+  TestDeadCodeAnalysis.cpp
   TestLiveness.cpp
   TestMatchReduction.cpp
   TestMemRefBoundCheck.cpp

diff  --git a/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp
new file mode 100644
index 0000000000000..a7c33526c80de
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestDeadCodeAnalysis.cpp
@@ -0,0 +1,116 @@
+//===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SparseDataFlowAnalysis.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+/// Write `state` to `os`, or "dead" when the solver never created it.
+static void printLiveness(const Executable *state, raw_ostream &os) {
+  if (state)
+    os << *state;
+  else
+    os << "dead";
+}
+
+/// For every operation nested in `op` (walked with `Operation::walk`) that
+/// carries a `tag` string attribute, print to `os`:
+///  - the tag,
+///  - for each region, the liveness of each block and of each incoming CFG
+///    edge, then the predecessors recorded for the region's entry block,
+///  - the predecessors recorded for the operation itself.
+/// States the solver never created are printed as "dead" or omitted.
+static void printAnalysisResults(DataFlowSolver &solver, Operation *op,
+                                 raw_ostream &os) {
+  op->walk([&](Operation *tagged) {
+    auto tag = tagged->getAttrOfType<StringAttr>("tag");
+    if (!tag)
+      return;
+    os << tag.getValue() << ":\n";
+    for (Region &region : tagged->getRegions()) {
+      os << " region #" << region.getRegionNumber() << "\n";
+      for (Block &block : region) {
+        os << "  ";
+        block.printAsOperand(os);
+        os << " = ";
+        printLiveness(solver.lookupState<Executable>(&block), os);
+        os << "\n";
+        for (Block *pred : block.getPredecessors()) {
+          os << "   from ";
+          pred->printAsOperand(os);
+          os << " = ";
+          printLiveness(solver.lookupState<Executable>(
+                            solver.getProgramPoint<CFGEdge>(pred, &block)),
+                        os);
+          os << "\n";
+        }
+      }
+      if (region.empty())
+        continue;
+      if (const auto *regionPreds =
+              solver.lookupState<PredecessorState>(&region.front()))
+        os << "region_preds: " << *regionPreds << "\n";
+    }
+    if (const auto *opPreds = solver.lookupState<PredecessorState>(tagged))
+      os << "op_preds: " << *opPreds << "\n";
+  });
+}
+
+namespace {
+/// A minimal constant-value provider for this test. For each constant-like
+/// operation it joins the matched constant attribute into the
+/// `Lattice<ConstantValue>` of the operation's first result, which lets
+/// `DeadCodeAnalysis` resolve branches on constant conditions.
+struct ConstantAnalysis : public DataFlowAnalysis {
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  /// Visit every operation under `top` that has the `ConstantLike` trait,
+  /// stopping at the first visit that fails.
+  LogicalResult initialize(Operation *top) override {
+    WalkResult result = top->walk([&](Operation *op) {
+      if (op->hasTrait<OpTrait::ConstantLike>())
+        if (failed(visit(op)))
+          return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    return success(!result.wasInterrupted());
+  }
+
+  /// Match `point` against `m_Constant` and, on success, join the constant
+  /// into the lattice of its result. This analysis never subscribes to any
+  /// state, so presumably only the operations passed in by `initialize` reach
+  /// here (`get<Operation *>` asserts that `point` is an operation).
+  LogicalResult visit(ProgramPoint point) override {
+    Operation *op = point.get<Operation *>();
+    Attribute value;
+    if (matchPattern(op, m_Constant(&value))) {
+      auto *constant = getOrCreate<Lattice<ConstantValue>>(op->getResult(0));
+      propagateIfChanged(
+          constant, constant->join(ConstantValue(value, op->getDialect())));
+    }
+    return success();
+  }
+};
+
+/// Test pass that runs `DeadCodeAnalysis` together with `ConstantAnalysis` on
+/// the pass's operation. It prints the liveness and predecessor information
+/// for every operation carrying a `tag` attribute to stderr. Because constant
+/// values are propagated, the untaken successors of branches on constant
+/// conditions are reported as dead rather than everything being live.
+struct TestDeadCodeAnalysisPass
+    : public PassWrapper<TestDeadCodeAnalysisPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDeadCodeAnalysisPass)
+
+  StringRef getArgument() const override { return "test-dead-code-analysis"; }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+
+    DataFlowSolver solver;
+    solver.load<DeadCodeAnalysis>();
+    solver.load<ConstantAnalysis>();
+    if (failed(solver.initializeAndRun(op)))
+      return signalPassFailure();
+    printAnalysisResults(solver, op, llvm::errs());
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Register the `-test-dead-code-analysis` pass so tools such as mlir-opt can
+/// run it by name.
+void registerTestDeadCodeAnalysisPass() {
+  // Registration is a side effect of constructing this object.
+  PassRegistration<TestDeadCodeAnalysisPass> registration;
+}
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 9e872ab63f5f4..ecd85a7f3aaa4 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@ void registerTestGpuSerializeToCubinPass();
 void registerTestGpuSerializeToHsacoPass();
 void registerTestDataFlowPass();
 void registerTestDataLayoutQuery();
+void registerTestDeadCodeAnalysisPass();
 void registerTestDecomposeCallGraphTypes();
 void registerTestDiagnosticsPass();
 void registerTestDominancePass();
@@ -173,6 +174,7 @@ void registerTestPasses() {
   mlir::test::registerTestDecomposeCallGraphTypes();
   mlir::test::registerTestDataFlowPass();
   mlir::test::registerTestDataLayoutQuery();
+  mlir::test::registerTestDeadCodeAnalysisPass();
   mlir::test::registerTestDominancePass();
   mlir::test::registerTestDynamicPipelinePass();
   mlir::test::registerTestExpandMathPass();


        


More information about the llvm-branch-commits mailing list