[Mlir-commits] [mlir] d07c90e - [mlir] Refactor the forward dataflow propagation in SCCP into a generic framework
River Riddle
llvmlistbot at llvm.org
Mon Apr 26 19:40:18 PDT 2021
Author: River Riddle
Date: 2021-04-26T19:39:46-07:00
New Revision: d07c90e39550e6b708d9bd262697a4b92bae860a
URL: https://github.com/llvm/llvm-project/commit/d07c90e39550e6b708d9bd262697a4b92bae860a
DIFF: https://github.com/llvm/llvm-project/commit/d07c90e39550e6b708d9bd262697a4b92bae860a.diff
LOG: [mlir] Refactor the forward dataflow propagation in SCCP into a generic framework
This revision takes the forward value propagation engine in SCCP and refactors it into a more generalized forward dataflow analysis framework. This framework allows for propagating information about values across the various control flow constructs in MLIR, and removes the need for users to reinvent the traversal (often not as completely). There are a few aspects of the traversal that were conservative for SCCP and that should be relaxed to support the needs of different value analyses. To keep this revision simple, these conservative behaviors will be left in. (Note that this won't produce an incorrect result, but may produce more conservative results than necessary in certain edge cases, e.g. region entry arguments for non-region branch interface operations.) The framework also only focuses on computing lattices for values, given the SCCP origins, but this is something to relax as needed in the future.
Given that this logic is already in SCCP, a majority of this commit is NFC. The more interesting parts are the interface glue that clients interact with.
Differential Revision: https://reviews.llvm.org/D100915
Added:
mlir/docs/Tutorials/DataFlowAnalysis.md
mlir/include/mlir/Analysis/DataFlowAnalysis.h
mlir/lib/Analysis/DataFlowAnalysis.cpp
Modified:
mlir/lib/Analysis/CMakeLists.txt
mlir/lib/Transforms/SCCP.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/DataFlowAnalysis.md b/mlir/docs/Tutorials/DataFlowAnalysis.md
new file mode 100644
index 0000000000000..ee8d9cf80c0e7
--- /dev/null
+++ b/mlir/docs/Tutorials/DataFlowAnalysis.md
@@ -0,0 +1,293 @@
+# Writing DataFlow Analyses in MLIR
+
+Writing dataflow analyses in MLIR, or well, any compiler, can often seem quite
+daunting and/or complex. A dataflow analysis generally involves propagating
+information about the IR across various different types of control flow
+constructs, of which MLIR has many (Block-based branches, Region-based branches,
+CallGraph, etc), and it isn't always clear how best to go about performing the
+propagation. To help with writing these types of analyses in MLIR, this document
+details several utilities that simplify the process and make it a bit more
+approachable.
+
+## Forward Dataflow Analysis
+
+One type of dataflow analysis is a forward propagation analysis. This type of
+analysis, as the name may suggest, propagates information forward (e.g. from
+definitions to uses). To provide a bit of concrete context, let's go over
+writing a simple forward dataflow analysis in MLIR. Let's say for this analysis
+that we want to propagate information about a special "metadata" dictionary
+attribute. The contents of this attribute are simply a set of metadata that
+describe a specific value, e.g. `metadata = { likes_pizza = true }`. We will
+collect the `metadata` for operations in the IR and propagate them about.
+
+### Lattices
+
+Before going into how one might set up the analysis itself, it is important to
+first introduce the concept of a `Lattice` and how we will use it for the
+analysis. A lattice represents all of the possible values or results of the
+analysis for a given value. A lattice element holds the set of information
+computed by the analysis for a given value, and is what gets propagated across
+the IR. For our analysis, this would correspond to the `metadata` dictionary
+attribute.
+
+Regardless of the value held within, every type of lattice contains two special
+element states:
+
+* `uninitialized`
+
+ - The element has not been initialized.
+
+* `top`/`overdefined`/`unknown`
+
+ - The element encompasses every possible value.
+ - This is a very conservative state, and essentially means "I can't make
+ any assumptions about the value; it could be anything".
+
+These two states are important when merging, or `join`ing as we will refer to it
+further in this document, information as part of the analysis. Lattice elements
+are `join`ed whenever there are two different source points, such as an argument
+to a block with multiple predecessors. One important note about the `join`
+operation is that it is required to be monotonic (see the `join` method in the
+example below for more information). This ensures that `join`ing elements is
+consistent. The two special states mentioned above have unique properties during
+a `join`:
+
+* `uninitialized`
+
+ - If one of the elements is `uninitialized`, the other element is used.
+ - `uninitialized` in the context of a `join` essentially means "take the
+ other thing".
+
+* `top`/`overdefined`/`unknown`
+
+ - If one of the elements being joined is `overdefined`, the result is
+ `overdefined`.
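To make the two special states concrete, here is a minimal standalone sketch (plain C++17, not the MLIR API). It assumes an SCCP-style constant lattice, `uninitialized < constant(c) < overdefined`, held in a hypothetical `ConstElement` type; the hypothetical `join` shows `uninitialized` acting as "take the other thing" and `overdefined` absorbing whatever it meets.

```cpp
#include <cassert>
#include <optional>

// Hypothetical three-level lattice in the style of SCCP constant propagation:
//   uninitialized  <  constant(c)  <  overdefined
struct ConstElement {
  bool overdefined = false;     // `top`: the value could be anything
  std::optional<int> constant;  // set when a single known constant is held

  static ConstElement uninitialized() { return {}; }
  static ConstElement ofConstant(int c) { return {false, c}; }
  static ConstElement top() { return {true, std::nullopt}; }

  bool isUninitialized() const { return !overdefined && !constant; }

  bool operator==(const ConstElement &rhs) const {
    return overdefined == rhs.overdefined && constant == rhs.constant;
  }
};

ConstElement join(const ConstElement &lhs, const ConstElement &rhs) {
  // `uninitialized` means "take the other thing".
  if (lhs.isUninitialized())
    return rhs;
  if (rhs.isUninitialized())
    return lhs;
  // `overdefined` absorbs everything it is joined with.
  if (lhs.overdefined || rhs.overdefined)
    return ConstElement::top();
  // Two constants only stay precise when they agree.
  return *lhs.constant == *rhs.constant ? lhs : ConstElement::top();
}
```

For example, joining `constant(1)` with `uninitialized` yields `constant(1)`, while joining `constant(1)` with `constant(2)` falls to `overdefined`.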
+
+For our analysis in MLIR, we will need to define a class representing the value
+held by an element of the lattice used by our dataflow analysis:
+
+```c++
+/// The value of our lattice represents the inner structure of a DictionaryAttr,
+/// for the `metadata`.
+struct MetadataLatticeValue {
+ MetadataLatticeValue() = default;
+ /// Compute a lattice value from the provided dictionary.
+ MetadataLatticeValue(DictionaryAttr attr)
+ : metadata(attr.begin(), attr.end()) {}
+
+ /// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown`
+ /// state, for our value type. The resultant state should not assume any
+ /// information about the state of the IR.
+ static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) {
+ // The `top`/`overdefined`/`unknown` state is when we know nothing about any
+ // metadata, i.e. an empty dictionary.
+ return MetadataLatticeValue();
+ }
+ /// Return a pessimistic value state for our value type using only information
+ /// about the state of the provided IR. This is similar to the above method,
+ /// but may produce a slightly more refined result. This is okay, as the
+ /// information is already encoded as fact in the IR.
+ static MetadataLatticeValue getPessimisticValueState(Value value) {
+ // Check to see if the parent operation has metadata.
+ if (Operation *parentOp = value.getDefiningOp()) {
+ if (auto metadata = parentOp->getAttrOfType<DictionaryAttr>("metadata"))
+ return MetadataLatticeValue(metadata);
+
+ // If no metadata is present, fall back to the
+ // `top`/`overdefined`/`unknown` state.
+ }
+ return MetadataLatticeValue();
+ }
+
+ /// This method conservatively joins the information held by `lhs` and `rhs`
+ /// into a new value. This method is required to be monotonic. `monotonicity`
+ /// is implied by the satisfaction of the following axioms:
+ /// * idempotence: join(x,x) == x
+ /// * commutativity: join(x,y) == join(y,x)
+ /// * associativity: join(x,join(y,z)) == join(join(x,y),z)
+ ///
+ /// When the above axioms are satisfied, we achieve `monotonicity`:
+ /// * monotonicity: join(x, join(x,y)) == join(x,y)
+ static MetadataLatticeValue join(const MetadataLatticeValue &lhs,
+ const MetadataLatticeValue &rhs) {
+ // To join `lhs` and `rhs` we will define a simple policy, which is that we
+ // only keep information that is the same. This means that we only keep
+ // facts that are true in both.
+ MetadataLatticeValue result;
+ for (const auto &lhsIt : lhs.metadata) {
+ // As noted above, we only merge if the values are the same.
+ auto it = rhs.metadata.find(lhsIt.first);
+ if (it == rhs.metadata.end() || it->second != lhsIt.second)
+ continue;
+ result.metadata.insert(lhsIt);
+ }
+ return result;
+ }
+
+ /// A simple comparator that checks to see if this value is equal to the one
+ /// provided.
+ bool operator==(const MetadataLatticeValue &rhs) const {
+ if (metadata.size() != rhs.metadata.size())
+ return false;
+ // Check that the 'rhs' contains the same metadata, with the same values.
+ return llvm::all_of(metadata, [&](auto &it) {
+ auto rhsIt = rhs.metadata.find(it.first);
+ return rhsIt != rhs.metadata.end() && rhsIt->second == it.second;
+ });
+ }
+
+ /// Our value represents the combined metadata, which is originally a
+ /// DictionaryAttr, so we use a map.
+ DenseMap<Identifier, Attribute> metadata;
+};
+```
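The intersection policy used by `join` above can be exercised outside of MLIR. This is a sketch under stated assumptions: `std::map<std::string, std::string>` stands in for `DenseMap<Identifier, Attribute>` so it compiles without MLIR, and the hypothetical `checkJoinAxioms` helper spot-checks idempotence, commutativity and associativity on sample inputs (a spot check, not a proof).

```cpp
#include <cassert>
#include <map>
#include <string>

// Stand-in for MetadataLatticeValue: metadata name -> printed attribute value.
using Metadata = std::map<std::string, std::string>;

// Keep only the entries whose key and value agree on both sides.
Metadata joinMetadata(const Metadata &lhs, const Metadata &rhs) {
  Metadata result;
  for (const auto &entry : lhs) {
    auto it = rhs.find(entry.first);
    if (it != rhs.end() && it->second == entry.second)
      result.insert(entry);
  }
  return result;
}

// Hypothetical helper: spot-check the algebraic laws required of `join`.
bool checkJoinAxioms(const Metadata &x, const Metadata &y, const Metadata &z) {
  bool idempotent = joinMetadata(x, x) == x;
  bool commutative = joinMetadata(x, y) == joinMetadata(y, x);
  bool associative = joinMetadata(x, joinMetadata(y, z)) ==
                     joinMetadata(joinMetadata(x, y), z);
  return idempotent && commutative && associative;
}
```

Joining `{likes_pizza = true, color = red}` with `{likes_pizza = true, color = blue}` keeps only `likes_pizza = true`, since `color` disagrees.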
+
+One interesting thing to note above is that we don't have an explicit method for
+the `uninitialized` state. This state is handled by the `LatticeElement` class,
+which manages a lattice value for a given IR entity. A quick overview of this
+class, and the API that will be interesting to us while writing our analysis, is
+shown below:
+
+```c++
+/// This class represents a lattice element holding a specific value of type
+/// `ValueT`.
+template <typename ValueT>
+class LatticeElement ... {
+public:
+ /// Return the value held by this element. This requires that a value is
+ /// known, i.e. not `uninitialized`.
+ ValueT &getValue();
+ const ValueT &getValue() const;
+
+ /// Join the information contained in the 'rhs' element into this
+ /// element. Returns if the state of the current element changed.
+ ChangeResult join(const LatticeElement<ValueT> &rhs);
+
+ /// Join the information contained in the 'rhs' value into this
+ /// lattice. Returns if the state of the current lattice changed.
+ ChangeResult join(const ValueT &rhs);
+
+ /// Mark the lattice element as having reached a pessimistic fixpoint. This
+ /// means that the lattice may potentially have conflicting value states, and
+ /// only the conservatively known value state should be relied on.
+ ChangeResult markPessimisticFixpoint();
+};
+```
+
+With our lattice defined, we can now define the driver that will compute and
+propagate our lattice across the IR.
+
+### ForwardDataFlowAnalysis Driver
+
+The `ForwardDataFlowAnalysis` class represents the driver of the dataflow
+analysis, and performs all of the related analysis computation. When defining
+our analysis, we will inherit from this class and implement some of its hooks.
+Before that, let's look at a quick overview of this class and some of the
+important API for our analysis:
+
+```c++
+/// This class represents the main driver of the forward dataflow analysis. It
+/// takes as a template parameter the value type of lattice being computed.
+template <typename ValueT>
+class ForwardDataFlowAnalysis : ... {
+public:
+ ForwardDataFlowAnalysis(MLIRContext *context);
+
+ /// Compute the analysis on operations rooted under the given top-level
+ /// operation. Note that the top-level operation is not visited.
+ void run(Operation *topLevelOp);
+
+ /// Return the lattice element attached to the given value. If a lattice has
+ /// not been added for the given value, a new 'uninitialized' value is
+ /// inserted and returned.
+ LatticeElement<ValueT> &getLatticeElement(Value value);
+
+ /// Return the lattice element attached to the given value, or nullptr if no
+ /// lattice element for the value has yet been created.
+ LatticeElement<ValueT> *lookupLatticeElement(Value value);
+
+ /// Mark all of the lattice elements for the given range of Values as having
+ /// reached a pessimistic fixpoint.
+ ChangeResult markAllPessimisticFixpoint(ValueRange values);
+
+protected:
+ /// Visit the given operation, and join any necessary analysis state
+ /// into the lattice elements for the results and block arguments owned by
+ /// this operation using the provided set of operand lattice elements
+ /// (all pointer values are guaranteed to be non-null). Returns if any result
+ /// or block argument value lattice elements changed during the visit. The
+ /// lattice element for a result or block argument value can be obtained, and
+ /// join'ed into, by using `getLatticeElement`.
+ virtual ChangeResult visitOperation(
+ Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) = 0;
+};
+```
+
+NOTE: Some API has been redacted for our example. The `ForwardDataFlowAnalysis`
+contains various other hooks that allow for injecting custom behavior when
+applicable.
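At its core, the driver is a worklist fixpoint: whenever a lattice element changes, the users of that value are revisited. The sketch below illustrates only that idea on a hypothetical toy IR (`Node`, `Lattice`, `transfer` and `solve` are all invented for illustration), not on MLIR operations, and it is not how `ForwardDataFlowAnalysis::run` is implemented; the real driver also handles blocks, regions and calls. The `Merge` node plays the role of a block argument and, like the framework along unreached backedges, ignores operands that are still uninitialized.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

// Hypothetical SCCP-style lattice: Uninit < Const(value) < Top.
struct Lattice {
  enum Kind { Uninit, Const, Top } kind = Uninit;
  int value = 0;
  bool operator==(const Lattice &o) const {
    return kind == o.kind && (kind != Const || value == o.value);
  }
};

Lattice join(const Lattice &a, const Lattice &b) {
  if (a.kind == Lattice::Uninit)
    return b;  // uninitialized: take the other side
  if (b.kind == Lattice::Uninit)
    return a;
  if (a == b)
    return a;
  return {Lattice::Top, 0};  // conflicting constants, or either side is Top
}

// A toy SSA "operation": a constant, an addition of two operands, or a merge
// point (like a block argument) that joins all of its operands.
struct Node {
  enum Op { Constant, Add, Merge } op;
  std::vector<std::size_t> operands;
  int constant = 0;
};

// The transfer function: compute a node's lattice from its operands' lattices.
Lattice transfer(const Node &n, const std::vector<Lattice> &state) {
  switch (n.op) {
  case Node::Constant:
    return {Lattice::Const, n.constant};
  case Node::Add: {
    const Lattice &a = state[n.operands[0]], &b = state[n.operands[1]];
    if (a.kind == Lattice::Uninit || b.kind == Lattice::Uninit)
      return {};  // wait until both inputs are known
    if (a.kind == Lattice::Top || b.kind == Lattice::Top)
      return {Lattice::Top, 0};
    return {Lattice::Const, a.value + b.value};
  }
  case Node::Merge: {
    Lattice result;  // starts uninitialized; unreached inputs are skipped
    for (std::size_t operand : n.operands)
      result = join(result, state[operand]);
    return result;
  }
  }
  return {Lattice::Top, 0};
}

std::vector<Lattice> solve(const std::vector<Node> &nodes) {
  std::vector<Lattice> state(nodes.size());
  std::vector<std::vector<std::size_t>> users(nodes.size());
  for (std::size_t i = 0; i < nodes.size(); ++i)
    for (std::size_t operand : nodes[i].operands)
      users[operand].push_back(i);

  std::queue<std::size_t> worklist;
  for (std::size_t i = 0; i < nodes.size(); ++i)
    worklist.push(i);
  while (!worklist.empty()) {
    std::size_t i = worklist.front();
    worklist.pop();
    Lattice updated = join(state[i], transfer(nodes[i], state));
    if (updated == state[i])
      continue;  // the equivalent of ChangeResult::NoChange
    state[i] = updated;
    for (std::size_t user : users[i])
      worklist.push(user);  // re-run the transfer function of each user
  }
  return state;
}
```

Because every update goes through `join`, each element can only move up the three-level lattice, so the loop terminates. A loop that adds `0` to an initial `1` stays `Const(1)`, while one that adds `1` each iteration ends up `Top`.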
+
+The main API that we are responsible for defining is the `visitOperation`
+method. This method is responsible for computing new lattice elements for the
+results and block arguments owned by the given operation. This is where we will
+inject the lattice element computation logic, also known as the transfer
+function for the operation, that is specific to our analysis. A simple
+implementation for our example is shown below:
+
+```c++
+class MetadataAnalysis : public ForwardDataFlowAnalysis<MetadataLatticeValue> {
+public:
+ using ForwardDataFlowAnalysis<MetadataLatticeValue>::ForwardDataFlowAnalysis;
+
+ ChangeResult visitOperation(
+ Operation *op,
+ ArrayRef<LatticeElement<MetadataLatticeValue> *> operands) override {
+ DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");
+
+ // If we have no metadata for this operation, we will conservatively mark
+ // all of the results as having reached a pessimistic fixpoint.
+ if (!metadata)
+ return markAllPessimisticFixpoint(op->getResults());
+
+ // Otherwise, we will compute a lattice value for the metadata and join it
+ // into the current lattice element for all of our results.
+ MetadataLatticeValue latticeValue(metadata);
+ ChangeResult result = ChangeResult::NoChange;
+ for (Value value : op->getResults()) {
+ // We grab the lattice element for `value` via `getLatticeElement` and
+ // then join it with the lattice value for this operation's metadata. Note
+ // that during the analysis phase, it is fine to freely create a new
+ // lattice element for a value. This is why we don't use the
+ // `lookupLatticeElement` method here.
+ result |= getLatticeElement(value).join(latticeValue);
+ }
+ return result;
+ }
+};
+```
+
+With that, we have all of the necessary components to compute our analysis.
+After the analysis has been computed, we can grab any computed information for
+values by using `lookupLatticeElement`. We use this function over
+`getLatticeElement` as the analysis is not guaranteed to visit all values (e.g.
+if the value is in an unreachable block), and we don't want to create a new
+uninitialized lattice element in this case. See below for a quick example:
+
+```c++
+void MyPass::runOnOperation() {
+ MetadataAnalysis analysis(&getContext());
+ analysis.run(getOperation());
+ // ...
+}
+
+void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) {
+ LatticeElement<MetadataLatticeValue> *latticeElement =
+ analysis.lookupLatticeElement(value);
+
+ // If we don't have an element, the `value` wasn't visited during our analysis
+ // meaning that it could be dead. We need to treat this conservatively. The
+ // same applies to an element that was created but never initialized.
+ if (!latticeElement || latticeElement->isUninitialized())
+ return;
+
+ // Our lattice element has a value, use it:
+ MetadataLatticeValue &metadataValue = latticeElement->getValue();
+ // ...
+}
+```
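The difference between `getLatticeElement` and `lookupLatticeElement` is the usual get-or-create versus find distinction. The standalone sketch below uses a hypothetical `LatticeTable` over `std::unordered_map`, with integer ids standing in for `mlir::Value` and an empty `std::optional` standing in for an uninitialized element. It shows why querying with the creating variant after the analysis would silently add uninitialized entries.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <unordered_map>

// Hypothetical stand-in for the analysis' value -> lattice element map.
// `int` ids replace mlir::Value; an empty optional means "uninitialized".
class LatticeTable {
public:
  // Like getLatticeElement: creates an uninitialized element on first use.
  std::optional<int> &get(int valueId) { return elements[valueId]; }

  // Like lookupLatticeElement: never creates, returns nullptr if absent.
  std::optional<int> *lookup(int valueId) {
    auto it = elements.find(valueId);
    return it == elements.end() ? nullptr : &it->second;
  }

  std::size_t size() const { return elements.size(); }

private:
  std::unordered_map<int, std::optional<int>> elements;
};
```

After `get(2)` is called on an unseen id, the table grows by one entry that holds no value, which is exactly the situation the tutorial avoids by using `lookupLatticeElement` once the analysis has run.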
diff --git a/mlir/include/mlir/Analysis/DataFlowAnalysis.h b/mlir/include/mlir/Analysis/DataFlowAnalysis.h
new file mode 100644
index 0000000000000..cc83947c84f41
--- /dev/null
+++ b/mlir/include/mlir/Analysis/DataFlowAnalysis.h
@@ -0,0 +1,401 @@
+//===- DataFlowAnalysis.h - General DataFlow Analysis Utilities -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains several utilities and algorithms that perform abstract
+// dataflow analysis over the IR. These allow for users to hook into various
+// analysis propagation algorithms without needing to reinvent the traversal
+// over the different types of control structures present within MLIR, such as
+// regions, the callgraph, etc. A few of the main entry points are detailed
+// below:
+//
+// ForwardDataFlowAnalysis:
+// This class provides support for defining dataflow algorithms that are
+// forward, sparse, pessimistic (except along unreached backedges) and
+// context-insensitive for the interprocedural aspects.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_DATAFLOWANALYSIS_H
+#define MLIR_ANALYSIS_DATAFLOWANALYSIS_H
+
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/Support/Allocator.h"
+
+namespace mlir {
+//===----------------------------------------------------------------------===//
+// ChangeResult
+//===----------------------------------------------------------------------===//
+
+/// A result type used to indicate if a change happened. Boolean operations on
+/// ChangeResult behave as though `Change` is truthy.
+enum class ChangeResult {
+ NoChange,
+ Change,
+};
+inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
+ return lhs == ChangeResult::Change ? lhs : rhs;
+}
+inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
+ lhs = lhs | rhs;
+ return lhs;
+}
+inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
+ return lhs == ChangeResult::NoChange ? lhs : rhs;
+}
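Because `ChangeResult` is an `enum class`, it does not implicitly convert to `bool`, which is why the header spells out `|`, `|=` and `&`. The operators can be checked in isolation: the block below copies them from the header and adds a hypothetical `anyChanged` helper that folds per-value results the same way `markAllPessimisticFixpoint` does.

```cpp
#include <cassert>
#include <vector>

enum class ChangeResult { NoChange, Change };

// Copied from the header: `Change` behaves like `true`.
inline ChangeResult operator|(ChangeResult lhs, ChangeResult rhs) {
  return lhs == ChangeResult::Change ? lhs : rhs;
}
inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
  lhs = lhs | rhs;
  return lhs;
}
inline ChangeResult operator&(ChangeResult lhs, ChangeResult rhs) {
  return lhs == ChangeResult::NoChange ? lhs : rhs;
}

// Hypothetical helper: report a change if any individual update changed.
ChangeResult anyChanged(const std::vector<ChangeResult> &results) {
  ChangeResult result = ChangeResult::NoChange;
  for (ChangeResult r : results)
    result |= r;
  return result;
}
```

An empty range yields `NoChange`, so a visit that touches no lattice elements does not schedule any further work.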
+
+//===----------------------------------------------------------------------===//
+// AbstractLatticeElement
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// This class represents an abstract lattice. A lattice is what gets propagated
+/// across the IR, and contains the information for a specific Value.
+class AbstractLatticeElement {
+public:
+ virtual ~AbstractLatticeElement();
+
+ /// Returns true if the value of this lattice is uninitialized, meaning that
+ /// it hasn't yet been initialized.
+ virtual bool isUninitialized() const = 0;
+
+ /// Join the information contained in 'rhs' into this lattice. Returns
+ /// if the value of the lattice changed.
+ virtual ChangeResult join(const AbstractLatticeElement &rhs) = 0;
+
+ /// Mark the lattice element as having reached a pessimistic fixpoint. This
+ /// means that the lattice may potentially have conflicting value states, and
+ /// only the most conservative value should be relied on.
+ virtual ChangeResult markPessimisticFixpoint() = 0;
+
+ /// Mark the lattice element as having reached an optimistic fixpoint. This
+ /// means that we optimistically assume the current value is the true state.
+ virtual void markOptimisticFixpoint() = 0;
+
+ /// Returns true if the lattice has reached a fixpoint. A fixpoint is when the
+ /// information optimistically assumed to be true is the same as the
+ /// information known to be true.
+ virtual bool isAtFixpoint() const = 0;
+};
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// LatticeElement
+//===----------------------------------------------------------------------===//
+
+/// This class represents a lattice holding a specific value of type `ValueT`.
+/// Lattice values (`ValueT`) are required to adhere to the following:
+/// * static ValueT join(const ValueT &lhs, const ValueT &rhs);
+/// - This method conservatively joins the information held by `lhs`
+/// and `rhs` into a new value. This method is required to be monotonic.
+/// * static ValueT getPessimisticValueState(MLIRContext *context);
+/// - This method computes a pessimistic/conservative value state assuming
+/// no information about the state of the IR.
+/// * static ValueT getPessimisticValueState(Value value);
+/// - This method computes a pessimistic/conservative value state for
+/// `value` assuming only information present in the current IR.
+/// * bool operator==(const ValueT &rhs) const;
+///
+template <typename ValueT>
+class LatticeElement final : public detail::AbstractLatticeElement {
+public:
+ LatticeElement() = delete;
+ LatticeElement(const ValueT &knownValue) : knownValue(knownValue) {}
+
+ /// Return the value held by this lattice. This requires that the value is
+ /// initialized.
+ ValueT &getValue() {
+ assert(!isUninitialized() && "expected known lattice element");
+ return *optimisticValue;
+ }
+ const ValueT &getValue() const {
+ assert(!isUninitialized() && "expected known lattice element");
+ return *optimisticValue;
+ }
+
+ /// Returns true if the value of this lattice hasn't yet been initialized.
+ bool isUninitialized() const final { return !optimisticValue.hasValue(); }
+
+ /// Join the information contained in the 'rhs' lattice into this
+ /// lattice. Returns if the state of the current lattice changed.
+ ChangeResult join(const detail::AbstractLatticeElement &rhs) final {
+ const LatticeElement<ValueT> &rhsLattice =
+ static_cast<const LatticeElement<ValueT> &>(rhs);
+
+ // If we are at a fixpoint, or rhs is uninitialized, there is nothing to do.
+ if (isAtFixpoint() || rhsLattice.isUninitialized())
+ return ChangeResult::NoChange;
+
+ // Join the rhs value into this lattice.
+ return join(rhsLattice.getValue());
+ }
+
+ /// Join the information contained in the 'rhs' value into this
+ /// lattice. Returns if the state of the current lattice changed.
+ ChangeResult join(const ValueT &rhs) {
+ // If the current lattice is uninitialized, copy the rhs value.
+ if (isUninitialized()) {
+ optimisticValue = rhs;
+ return ChangeResult::Change;
+ }
+
+ // Otherwise, join rhs with the current optimistic value.
+ ValueT newValue = ValueT::join(*optimisticValue, rhs);
+ assert(ValueT::join(newValue, *optimisticValue) == newValue &&
+ "expected `join` to be monotonic");
+ assert(ValueT::join(newValue, rhs) == newValue &&
+ "expected `join` to be monotonic");
+
+ // Update the current optimistic value if something changed.
+ if (newValue == optimisticValue)
+ return ChangeResult::NoChange;
+
+ optimisticValue = newValue;
+ return ChangeResult::Change;
+ }
+
+ /// Mark the lattice element as having reached a pessimistic fixpoint. This
+ /// means that the lattice may potentially have conflicting value states, and
+ /// only the conservatively known value state should be relied on.
+ ChangeResult markPessimisticFixpoint() final {
+ if (isAtFixpoint())
+ return ChangeResult::NoChange;
+
+ // For this fixed point, we take whatever we knew to be true and set that to
+ // our optimistic value.
+ optimisticValue = knownValue;
+ return ChangeResult::Change;
+ }
+
+ /// Mark the lattice element as having reached an optimistic fixpoint. This
+ /// means that we optimistically assume the current value is the true state.
+ void markOptimisticFixpoint() final {
+ assert(!isUninitialized() && "expected an initialized value");
+ knownValue = *optimisticValue;
+ }
+
+ /// Returns true if the lattice has reached a fixpoint. A fixpoint is when the
+ /// information optimistically assumed to be true is the same as the
+ /// information known to be true.
+ bool isAtFixpoint() const final { return optimisticValue == knownValue; }
+
+private:
+ /// The value that is conservatively known to be true.
+ ValueT knownValue;
+ /// The currently computed value that is optimistically assumed to be true, or
+ /// None if the lattice element is uninitialized.
+ Optional<ValueT> optimisticValue;
+};
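The split between `knownValue` and `optimisticValue` is what lets an element be uninitialized, optimistic, or pinned at a fixpoint. The following is a trimmed, MLIR-free restatement of that logic for illustration only: the name `MiniLatticeElement` is hypothetical, `std::optional` replaces `llvm::Optional`, and the value type is a hypothetical `Facts` set whose `join` is intersection and whose pessimistic state is the empty set, mirroring the metadata example from the tutorial.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <optional>
#include <set>
#include <string>

// Value type: a set of facts; join keeps only facts true on both sides, and
// the pessimistic state is the empty set (no facts known).
struct Facts {
  std::set<std::string> facts;
  static Facts join(const Facts &lhs, const Facts &rhs) {
    Facts result;
    std::set_intersection(lhs.facts.begin(), lhs.facts.end(),
                          rhs.facts.begin(), rhs.facts.end(),
                          std::inserter(result.facts, result.facts.end()));
    return result;
  }
  bool operator==(const Facts &rhs) const { return facts == rhs.facts; }
};

enum class ChangeResult { NoChange, Change };

// Trimmed restatement of LatticeElement<ValueT> from the header above.
template <typename ValueT>
class MiniLatticeElement {
public:
  explicit MiniLatticeElement(const ValueT &knownValue)
      : knownValue(knownValue) {}

  bool isUninitialized() const { return !optimisticValue.has_value(); }
  bool isAtFixpoint() const { return optimisticValue == knownValue; }
  const ValueT &getValue() const { return *optimisticValue; }

  ChangeResult join(const ValueT &rhs) {
    if (isUninitialized()) {  // first information: adopt it as-is
      optimisticValue = rhs;
      return ChangeResult::Change;
    }
    ValueT newValue = ValueT::join(*optimisticValue, rhs);
    if (newValue == *optimisticValue)
      return ChangeResult::NoChange;
    optimisticValue = newValue;
    return ChangeResult::Change;
  }

  ChangeResult markPessimisticFixpoint() {
    if (isAtFixpoint())
      return ChangeResult::NoChange;
    optimisticValue = knownValue;  // fall back to what is known to be true
    return ChangeResult::Change;
  }

private:
  ValueT knownValue;                      // conservatively known to be true
  std::optional<ValueT> optimisticValue;  // empty while uninitialized
};
```

Joining `{a, b}` and then `{a}` narrows the optimistic state to `{a}`; a subsequent pessimistic fixpoint resets it to the known (empty) state, after which further fixpoint requests report no change.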
+
+//===----------------------------------------------------------------------===//
+// ForwardDataFlowAnalysisBase
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// This class is the non-templated virtual base class for the
+/// ForwardDataFlowAnalysis. This class provides opaque hooks to the main
+/// algorithm.
+class ForwardDataFlowAnalysisBase {
+public:
+ virtual ~ForwardDataFlowAnalysisBase();
+
+ /// Initialize and compute the analysis on operations rooted under the given
+ /// top-level operation. Note that the top-level operation is not visited.
+ void run(Operation *topLevelOp);
+
+ /// Return the lattice element attached to the given value. If a lattice has
+ /// not been added for the given value, a new 'uninitialized' value is
+ /// inserted and returned.
+ AbstractLatticeElement &getLatticeElement(Value value);
+
+ /// Return the lattice element attached to the given value, or nullptr if no
+ /// lattice for the value has yet been created.
+ AbstractLatticeElement *lookupLatticeElement(Value value);
+
+ /// Visit the given operation, and join any necessary analysis state
+ /// into the lattices for the results and block arguments owned by this
+ /// operation using the provided set of operand lattice elements (all pointer
+ /// values are guaranteed to be non-null). Returns if any result or block
+ /// argument value lattices changed during the visit. The lattice for a result
+ /// or block argument value can be obtained and join'ed into by using
+ /// `getLatticeElement`.
+ virtual ChangeResult
+ visitOperation(Operation *op,
+ ArrayRef<AbstractLatticeElement *> operands) = 0;
+
+ /// Given a BranchOpInterface, and the current lattice elements that
+ /// correspond to the branch operands (all pointer values are guaranteed to be
+ /// non-null), try to compute a specific set of successors that would be
+ /// selected for the branch. Returns failure if not computable, or if all of
+ /// the successors would be chosen. If a subset of successors can be selected,
+ /// `successors` is populated.
+ virtual LogicalResult
+ getSuccessorsForOperands(BranchOpInterface branch,
+ ArrayRef<AbstractLatticeElement *> operands,
+ SmallVectorImpl<Block *> &successors) = 0;
+
+ /// Given a RegionBranchOpInterface, and the current lattice elements that
+ /// correspond to the branch operands (all pointer values are guaranteed to be
+ /// non-null), compute a specific set of region successors that would be
+ /// selected.
+ virtual void
+ getSuccessorsForOperands(RegionBranchOpInterface branch,
+ Optional<unsigned> sourceIndex,
+ ArrayRef<AbstractLatticeElement *> operands,
+ SmallVectorImpl<RegionSuccessor> &successors) = 0;
+
+ /// Create a new uninitialized lattice element. An optional value is provided
+ /// which, if valid, should be used to initialize the known conservative state
+ /// of the lattice.
+ virtual AbstractLatticeElement *createLatticeElement(Value value = {}) = 0;
+
+private:
+ /// A map from SSA value to lattice element.
+ DenseMap<Value, AbstractLatticeElement *> latticeValues;
+};
+} // namespace detail
+
+//===----------------------------------------------------------------------===//
+// ForwardDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// This class provides a general forward dataflow analysis driver
+/// utilizing the lattice classes defined above, to enable the easy definition
+/// of dataflow analysis algorithms. More specifically this driver is useful for
+/// defining analyses that are forward, sparse, pessimistic (except along
+/// unreached backedges) and context-insensitive for the interprocedural
+/// aspects.
+template <typename ValueT>
+class ForwardDataFlowAnalysis : public detail::ForwardDataFlowAnalysisBase {
+public:
+ ForwardDataFlowAnalysis(MLIRContext *context) : context(context) {}
+
+ /// Return the MLIR context used when constructing this analysis.
+ MLIRContext *getContext() { return context; }
+
+ /// Compute the analysis on operations rooted under the given top-level
+ /// operation. Note that the top-level operation is not visited.
+ void run(Operation *topLevelOp) {
+ detail::ForwardDataFlowAnalysisBase::run(topLevelOp);
+ }
+
+ /// Return the lattice element attached to the given value, or nullptr if no
+ /// lattice for the value has yet been created.
+ LatticeElement<ValueT> *lookupLatticeElement(Value value) {
+ return static_cast<LatticeElement<ValueT> *>(
+ detail::ForwardDataFlowAnalysisBase::lookupLatticeElement(value));
+ }
+
+protected:
+ /// Return the lattice element attached to the given value. If a lattice has
+ /// not been added for the given value, a new 'uninitialized' value is
+ /// inserted and returned.
+ LatticeElement<ValueT> &getLatticeElement(Value value) {
+ return static_cast<LatticeElement<ValueT> &>(
+ detail::ForwardDataFlowAnalysisBase::getLatticeElement(value));
+ }
+
+ /// Mark all of the lattices for the given range of Values as having reached a
+ /// pessimistic fixpoint.
+ ChangeResult markAllPessimisticFixpoint(ValueRange values) {
+ ChangeResult result = ChangeResult::NoChange;
+ for (Value value : values)
+ result |= getLatticeElement(value).markPessimisticFixpoint();
+ return result;
+ }
+
+ /// Visit the given operation, and join any necessary analysis state
+ /// into the lattices for the results and block arguments owned by this
+ /// operation using the provided set of operand lattice elements (all pointer
+ /// values are guaranteed to be non-null). Returns if any result or block
+ /// argument value lattices changed during the visit. The lattice for a result
+ /// or block argument value can be obtained by using
+ /// `getLatticeElement`.
+ virtual ChangeResult
+ visitOperation(Operation *op,
+ ArrayRef<LatticeElement<ValueT> *> operands) = 0;
+
+ /// Given a BranchOpInterface, and the current lattice elements that
+ /// correspond to the branch operands (all pointer values are guaranteed to be
+ /// non-null), try to compute a specific set of successors that would be
+ /// selected for the branch. Returns failure if not computable, or if all of
+ /// the successors would be chosen. If a subset of successors can be selected,
+ /// `successors` is populated.
+ virtual LogicalResult
+ getSuccessorsForOperands(BranchOpInterface branch,
+ ArrayRef<LatticeElement<ValueT> *> operands,
+ SmallVectorImpl<Block *> &successors) {
+ return failure();
+ }
+
+ /// Given a RegionBranchOpInterface, and the current lattice elements that
+ /// correspond to the branch operands (all pointer values are guaranteed to be
+ /// non-null), compute a specific set of region successors that would be
+ /// selected.
+ virtual void
+ getSuccessorsForOperands(RegionBranchOpInterface branch,
+ Optional<unsigned> sourceIndex,
+ ArrayRef<LatticeElement<ValueT> *> operands,
+ SmallVectorImpl<RegionSuccessor> &successors) {
+ SmallVector<Attribute> constantOperands(operands.size());
+ branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
+ }
+
+private:
+ /// Type-erased wrappers that convert the abstract lattice operands to derived
+ /// lattices and invoke the virtual hooks operating on the derived lattices.
+ ChangeResult
+ visitOperation(Operation *op,
+ ArrayRef<detail::AbstractLatticeElement *> operands) final {
+ LatticeElement<ValueT> *const *derivedOperandBase =
+ reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data());
+ return visitOperation(
+ op, llvm::makeArrayRef(derivedOperandBase, operands.size()));
+ }
+ LogicalResult
+ getSuccessorsForOperands(BranchOpInterface branch,
+ ArrayRef<detail::AbstractLatticeElement *> operands,
+ SmallVectorImpl<Block *> &successors) final {
+ LatticeElement<ValueT> *const *derivedOperandBase =
+ reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data());
+ return getSuccessorsForOperands(
+ branch, llvm::makeArrayRef(derivedOperandBase, operands.size()),
+ successors);
+ }
+ void
+ getSuccessorsForOperands(RegionBranchOpInterface branch,
+ Optional<unsigned> sourceIndex,
+ ArrayRef<detail::AbstractLatticeElement *> operands,
+ SmallVectorImpl<RegionSuccessor> &successors) final {
+ LatticeElement<ValueT> *const *derivedOperandBase =
+ reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data());
+ getSuccessorsForOperands(
+ branch, sourceIndex,
+ llvm::makeArrayRef(derivedOperandBase, operands.size()), successors);
+ }
+
+ /// Create a new uninitialized lattice element. An optional value is provided
+ /// which, if valid, should be used to initialize the known conservative state
+ /// of the lattice.
+ detail::AbstractLatticeElement *createLatticeElement(Value value) final {
+ ValueT knownValue = value ? ValueT::getPessimisticValueState(value)
+ : ValueT::getPessimisticValueState(context);
+ return new (allocator.Allocate()) LatticeElement<ValueT>(knownValue);
+ }
+
+ /// An allocator used for new lattice elements.
+ llvm::SpecificBumpPtrAllocator<LatticeElement<ValueT>> allocator;
+
+ /// The MLIRContext of this solver.
+ MLIRContext *context;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_DATAFLOWANALYSIS_H
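The header above seeds every lattice element with a known conservative state (`ValueT::getPessimisticValueState(...)` in `createLatticeElement`). The solver then refines the element with joins until it either stabilizes or collapses onto that known state. The standalone C++ sketch below illustrates that contract without MLIR. `ConstValue`, the simplified `LatticeElement`, and the `ChangeResult` operator are hypothetical stand-ins, not the header's definitions. It assumes a `ValueT` that provides a static `join` and `operator==`, and it uses a zero-argument `getPessimisticValueState` in place of the header's `MLIRContext *` and `Value` overloads.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for ChangeResult, with the `|=` accumulation used
// when several joins are folded into one result.
enum class ChangeResult { NoChange, Change };
inline ChangeResult &operator|=(ChangeResult &lhs, ChangeResult rhs) {
  if (rhs == ChangeResult::Change)
    lhs = ChangeResult::Change;
  return lhs;
}

// Hypothetical constant-propagation value: an empty optional means
// "not a constant", which is also the pessimistic state.
struct ConstValue {
  std::optional<long> constant;

  static ConstValue getPessimisticValueState() { return ConstValue{}; }

  // Joining two different constants loses all information.
  static ConstValue join(const ConstValue &lhs, const ConstValue &rhs) {
    return lhs == rhs ? lhs : getPessimisticValueState();
  }

  bool operator==(const ConstValue &rhs) const {
    return constant == rhs.constant;
  }
};

// Simplified lattice element: starts uninitialized, is refined by `join`, and
// is at its fixpoint once it equals the known (pessimistic) value.
template <typename ValueT>
class LatticeElement {
public:
  explicit LatticeElement(ValueT known) : knownValue(known) {}

  bool isUninitialized() const { return !optimisticValue.has_value(); }
  bool isAtFixpoint() const {
    return optimisticValue.has_value() && *optimisticValue == knownValue;
  }
  const ValueT &getValue() const {
    assert(!isUninitialized() && "querying an uninitialized lattice");
    return *optimisticValue;
  }

  ChangeResult join(const ValueT &rhs) {
    if (isAtFixpoint())
      return ChangeResult::NoChange;
    if (isUninitialized()) {
      optimisticValue = rhs;
      return ChangeResult::Change;
    }
    ValueT joined = ValueT::join(*optimisticValue, rhs);
    if (joined == *optimisticValue)
      return ChangeResult::NoChange;
    optimisticValue = joined;
    return ChangeResult::Change;
  }

  // An uninitialized rhs carries no information, so it merges with anything.
  ChangeResult join(const LatticeElement &rhs) {
    if (rhs.isUninitialized())
      return ChangeResult::NoChange;
    return join(*rhs.optimisticValue);
  }

  ChangeResult markPessimisticFixpoint() {
    if (isAtFixpoint())
      return ChangeResult::NoChange;
    optimisticValue = knownValue;
    return ChangeResult::Change;
  }

private:
  std::optional<ValueT> optimisticValue;
  ValueT knownValue;
};
```

Joining the constants 1 and 2 into the same element therefore drives it to the pessimistic state, after which further joins report `NoChange`. This is the property the solver relies on to terminate.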
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
index 6333f1a61c34a..0faf58e489687 100644
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@ set(LLVM_OPTIONAL_SOURCES
AffineStructures.cpp
BufferAliasAnalysis.cpp
CallGraph.cpp
+ DataFlowAnalysis.cpp
LinearTransform.cpp
Liveness.cpp
LoopAnalysis.cpp
@@ -20,6 +21,7 @@ add_mlir_library(MLIRAnalysis
AliasAnalysis.cpp
BufferAliasAnalysis.cpp
CallGraph.cpp
+ DataFlowAnalysis.cpp
Liveness.cpp
NumberOfExecutions.cpp
SliceAnalysis.cpp
diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp
new file mode 100644
index 0000000000000..ec1a13a742426
--- /dev/null
+++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp
@@ -0,0 +1,780 @@
+//===- DataFlowAnalysis.cpp -----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+namespace {
+/// This class contains various state used when computing the lattice elements
+/// of a callable operation.
+class CallableLatticeState {
+public:
+ /// Build a lattice state with a given callable region, and a specified number
+ /// of results to be initialized to the default lattice element.
+ CallableLatticeState(ForwardDataFlowAnalysisBase &analysis,
+ Region *callableRegion, unsigned numResults)
+ : callableArguments(callableRegion->getArguments()),
+ resultLatticeElements(numResults) {
+ for (AbstractLatticeElement *&it : resultLatticeElements)
+ it = analysis.createLatticeElement();
+ }
+
+ /// Returns the arguments to the callable region.
+ Block::BlockArgListType getCallableArguments() const {
+ return callableArguments;
+ }
+
+ /// Returns the lattice element for the results of the callable region.
+ auto getResultLatticeElements() {
+ return llvm::make_pointee_range(resultLatticeElements);
+ }
+
+ /// Add a call to this callable. This is only used if the callable defines a
+ /// symbol.
+ void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
+
+ /// Return the calls that reference this callable. This is only used
+ /// if the callable defines a symbol.
+ ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
+
+private:
+ /// The arguments of the callable region.
+ Block::BlockArgListType callableArguments;
+
+ /// The lattice state for each of the results of this region. The return
+ /// values of the callable aren't SSA values, so we need to track them
+ /// separately.
+ SmallVector<AbstractLatticeElement *, 4> resultLatticeElements;
+
+ /// The calls referencing this callable if this callable defines a symbol.
+ /// This removes the need to recompute symbol references during propagation.
+ /// Value based references are trivial to resolve, so they can be done
+ /// in-place.
+ SmallVector<Operation *, 4> symbolCalls;
+};
+
+/// This class represents the solver for a forward dataflow analysis. This class
+/// acts as the propagation engine for computing the lattice elements of values.
+class ForwardDataFlowSolver {
+public:
+ /// Initialize the solver with the given top-level operation.
+ ForwardDataFlowSolver(ForwardDataFlowAnalysisBase &analysis, Operation *op);
+
+ /// Run the solver until it converges.
+ void solve();
+
+private:
+ /// Initialize the set of symbol defining callables that can have their
+ /// arguments and results tracked. 'op' is the top-level operation that the
+ /// solver is operating on.
+ void initializeSymbolCallables(Operation *op);
+
+ /// Visit the users of the given IR that reside within executable blocks.
+ template <typename T>
+ void visitUsers(T &value) {
+ for (Operation *user : value.getUsers())
+ if (isBlockExecutable(user->getBlock()))
+ visitOperation(user);
+ }
+
+ /// Visit the given operation and compute any necessary lattice state.
+ void visitOperation(Operation *op);
+
+ /// Visit the given call operation and compute any necessary lattice state.
+ void visitCallOperation(CallOpInterface op);
+
+ /// Visit the given callable operation and compute any necessary lattice
+ /// state.
+ void visitCallableOperation(Operation *op);
+
+ /// Visit the given region branch operation, which defines regions, and
+ /// compute any necessary lattice state. This also resolves the lattice state
+ /// of both the operation results and any nested regions.
+ void visitRegionBranchOperation(
+ RegionBranchOpInterface branch,
+ ArrayRef<AbstractLatticeElement *> operandLattices);
+
+ /// Visit the given set of region successors, computing any necessary lattice
+ /// state. The provided function returns the input operands to the region at
+ /// the given index. If the index is 'None', the input operands correspond to
+ /// the parent operation results.
+ void visitRegionSuccessors(
+ Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
+ function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
+
+ /// Visit the given terminator operation and compute any necessary lattice
+ /// state.
+ void
+ visitTerminatorOperation(Operation *op,
+ ArrayRef<AbstractLatticeElement *> operandLattices);
+
+ /// Visit the given terminator operation that exits a callable region. These
+ /// are terminators with no CFG successors.
+ void visitCallableTerminatorOperation(
+ Operation *callable, Operation *terminator,
+ ArrayRef<AbstractLatticeElement *> operandLattices);
+
+ /// Visit the given block and compute any necessary lattice state.
+ void visitBlock(Block *block);
+
+ /// Visit argument #'i' of the given block and compute any necessary lattice
+ /// state.
+ void visitBlockArgument(Block *block, int i);
+
+ /// Mark the entry block of the given region as executable. Returns NoChange
+ /// if the block was already marked executable. If `markPessimisticFixpoint`
+ /// is true, the arguments of the entry block are also marked as having
+ /// reached the pessimistic fixpoint.
+ ChangeResult markEntryBlockExecutable(Region *region,
+ bool markPessimisticFixpoint);
+
+ /// Mark the given block as executable. Returns NoChange if the block was
+ /// already marked executable.
+ ChangeResult markBlockExecutable(Block *block);
+
+ /// Returns true if the given block is executable.
+ bool isBlockExecutable(Block *block) const;
+
+ /// Mark the edge between 'from' and 'to' as executable.
+ void markEdgeExecutable(Block *from, Block *to);
+
+ /// Return true if the edge between 'from' and 'to' is executable.
+ bool isEdgeExecutable(Block *from, Block *to) const;
+
+ /// Mark the given value as having reached the pessimistic fixpoint. This
+ /// means that we cannot further refine the state of this value.
+ void markPessimisticFixpoint(Value value);
+
+ /// Mark all of the given values as having reached the pessimistic fixpoint.
+ template <typename ValuesT>
+ void markAllPessimisticFixpoint(ValuesT values) {
+ for (auto value : values)
+ markPessimisticFixpoint(value);
+ }
+ template <typename ValuesT>
+ void markAllPessimisticFixpoint(Operation *op, ValuesT values) {
+ markAllPessimisticFixpoint(values);
+ opWorklist.push_back(op);
+ }
+ template <typename ValuesT>
+ void markAllPessimisticFixpointAndVisitUsers(ValuesT values) {
+ for (auto value : values) {
+ AbstractLatticeElement &lattice = analysis.getLatticeElement(value);
+ if (lattice.markPessimisticFixpoint() == ChangeResult::Change)
+ visitUsers(value);
+ }
+ }
+
+ /// Returns true if the given value was marked as having reached the
+ /// pessimistic fixpoint.
+ bool isAtFixpoint(Value value) const;
+
+ /// Merge in the given lattice 'from' into the lattice 'to'. 'owner'
+ /// corresponds to the parent operation of the lattice for 'to'.
+ void join(Operation *owner, AbstractLatticeElement &to,
+ const AbstractLatticeElement &from);
+
+ /// A reference to the dataflow analysis being computed.
+ ForwardDataFlowAnalysisBase &analysis;
+
+ /// The set of blocks that are known to execute, or are intrinsically live.
+ SmallPtrSet<Block *, 16> executableBlocks;
+
+ /// The set of control flow edges that are known to execute.
+ DenseSet<std::pair<Block *, Block *>> executableEdges;
+
+ /// A worklist containing blocks that need to be processed.
+ SmallVector<Block *, 64> blockWorklist;
+
+ /// A worklist of operations that need to be processed.
+ SmallVector<Operation *, 64> opWorklist;
+
+ /// The callable operations that have their argument/result state tracked.
+ DenseMap<Operation *, CallableLatticeState> callableLatticeState;
+
+ /// A map between a call operation and the resolved symbol callable. This
+ /// avoids re-resolving symbol references during propagation. Value based
+ /// callables are trivial to resolve, so they can be done in-place.
+ DenseMap<Operation *, Operation *> callToSymbolCallable;
+
+ /// A symbol table used for O(1) symbol lookups during simplification.
+ SymbolTableCollection symbolTable;
+};
+} // end anonymous namespace
+
+ForwardDataFlowSolver::ForwardDataFlowSolver(
+ ForwardDataFlowAnalysisBase &analysis, Operation *op)
+ : analysis(analysis) {
+ /// Initialize the solver with the regions within this operation.
+ for (Region &region : op->getRegions()) {
+ // Mark the entry block as executable. The values passed to these regions
+ // are also invisible, so mark any arguments as reaching the pessimistic
+ // fixpoint.
+ markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
+ }
+ initializeSymbolCallables(op);
+}
+
+void ForwardDataFlowSolver::solve() {
+ while (!blockWorklist.empty() || !opWorklist.empty()) {
+ // Process any operations in the op worklist.
+ while (!opWorklist.empty())
+ visitUsers(*opWorklist.pop_back_val());
+
+ // Process any blocks in the block worklist.
+ while (!blockWorklist.empty())
+ visitBlock(blockWorklist.pop_back_val());
+ }
+}
+
+void ForwardDataFlowSolver::initializeSymbolCallables(Operation *op) {
+ // Initialize the set of symbol callables that can have their state tracked.
+ // This tracks which symbol callable operations we can propagate within and
+ // out of.
+ auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
+ Region &symbolTableRegion = symTable->getRegion(0);
+ Block *symbolTableBlock = &symbolTableRegion.front();
+ for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
+ // We won't be able to track external callables.
+ Region *callableRegion = callable.getCallableRegion();
+ if (!callableRegion)
+ continue;
+ // We only care about symbol defining callables here.
+ auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
+ if (!symbol)
+ continue;
+ callableLatticeState.try_emplace(callable, analysis, callableRegion,
+ callable.getCallableResults().size());
+
+ // If not all of the uses of this symbol are visible, we can't track the
+ // state of the arguments.
+ if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
+ for (Region &region : callable->getRegions())
+ markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
+ }
+ }
+ if (callableLatticeState.empty())
+ return;
+
+ // After computing the valid callables, walk any symbol uses to check
+ // for non-call references. We won't be able to track the lattice state
+ // for arguments to these callables, as we can't guarantee that we can see
+ // all of its calls.
+ Optional<SymbolTable::UseRange> uses =
+ SymbolTable::getSymbolUses(&symbolTableRegion);
+ if (!uses) {
+ // If we couldn't gather the symbol uses, conservatively assume that
+ // we can't track information for any nested symbols.
+ op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
+ return;
+ }
+
+ for (const SymbolTable::SymbolUse &use : *uses) {
+ // If the use is a call, track it to avoid the need to recompute the
+ // reference later.
+ if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
+ Operation *symCallable = callOp.resolveCallable(&symbolTable);
+ auto callableLatticeIt = callableLatticeState.find(symCallable);
+ if (callableLatticeIt != callableLatticeState.end()) {
+ callToSymbolCallable.try_emplace(callOp, symCallable);
+
+ // We only need to record the call in the lattice if it produces any
+ // values.
+ if (callOp->getNumResults())
+ callableLatticeIt->second.addSymbolCall(callOp);
+ }
+ continue;
+ }
+ // This use isn't a call, so we don't know all of the callers.
+ auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
+ auto it = callableLatticeState.find(symbol);
+ if (it != callableLatticeState.end()) {
+ for (Region &region : it->first->getRegions())
+ markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
+ }
+ }
+ };
+ SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
+ walkFn);
+}
+
+void ForwardDataFlowSolver::visitOperation(Operation *op) {
+ // Collect all of the lattice elements feeding into this operation. If any are
+ // not yet resolved, bail out and wait for them to resolve.
+ SmallVector<AbstractLatticeElement *, 8> operandLattices;
+ operandLattices.reserve(op->getNumOperands());
+ for (Value operand : op->getOperands()) {
+ AbstractLatticeElement *operandLattice =
+ analysis.lookupLatticeElement(operand);
+ if (!operandLattice)
+ return;
+ operandLattices.push_back(operandLattice);
+ }
+
+ // If this is a terminator operation, process any control flow lattice state.
+ if (op->hasTrait<OpTrait::IsTerminator>())
+ visitTerminatorOperation(op, operandLattices);
+
+ // Process call operations. The call visitor processes result values, so we
+ // can exit afterwards.
+ if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
+ return visitCallOperation(call);
+
+ // Process callable operations. These are specially handled region operations
+ // that track dataflow via calls.
+ if (isa<CallableOpInterface>(op)) {
+ // If this callable has a tracked lattice state, it will be visited by calls
+ // that reference it instead. This way, we don't assume that it is
+ // executable unless there is a proper reference to it.
+ if (callableLatticeState.count(op))
+ return;
+ return visitCallableOperation(op);
+ }
+
+ // Process region holding operations.
+ if (op->getNumRegions()) {
+ // Check to see if we can reason about the internal control flow of this
+ // region operation.
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(op))
+ return visitRegionBranchOperation(branch, operandLattices);
+
+ // If we can't, conservatively mark all regions as executable.
+ // TODO: Let the `visitOperation` method decide how to propagate
+ // information to the block arguments.
+ for (Region &region : op->getRegions())
+ markEntryBlockExecutable(&region, /*markPessimisticFixpoint=*/true);
+ }
+
+ // If this op produces no results, there are no result lattices to compute.
+ if (op->getNumResults() == 0)
+ return;
+
+ // If all of the results of this operation are already resolved, bail out
+ // early.
+ auto isAtFixpointFn = [&](Value value) { return isAtFixpoint(value); };
+ if (llvm::all_of(op->getResults(), isAtFixpointFn))
+ return;
+
+ // Visit the current operation.
+ if (analysis.visitOperation(op, operandLattices) == ChangeResult::Change)
+ opWorklist.push_back(op);
+
+ // `visitOperation` is required to define all of the result lattices.
+ assert(llvm::none_of(
+ op->getResults(),
+ [&](Value value) {
+ return analysis.getLatticeElement(value).isUninitialized();
+ }) &&
+ "expected `visitOperation` to define all result lattices");
+}
+
+void ForwardDataFlowSolver::visitCallableOperation(Operation *op) {
+ // Mark the regions as executable. If we aren't tracking lattice state for
+ // this callable, mark all of the region arguments as having reached a
+ // fixpoint.
+ bool isTrackingLatticeState = callableLatticeState.count(op);
+ for (Region &region : op->getRegions())
+ markEntryBlockExecutable(&region, !isTrackingLatticeState);
+
+ // TODO: Add support for non-symbol callables when necessary. If the callable
+ // has non-call uses we would mark as having reached pessimistic fixpoint,
+ // otherwise allow for propagating the return values out.
+ markAllPessimisticFixpoint(op, op->getResults());
+}
+
+void ForwardDataFlowSolver::visitCallOperation(CallOpInterface op) {
+ ResultRange callResults = op->getResults();
+
+ // Resolve the callable operation for this call.
+ Operation *callableOp = nullptr;
+ if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
+ callableOp = callableValue.getDefiningOp();
+ else
+ callableOp = callToSymbolCallable.lookup(op);
+
+ // The callable of this call can't be resolved, mark any results overdefined.
+ if (!callableOp)
+ return markAllPessimisticFixpoint(op, callResults);
+
+ // If this callable is tracking state, merge the argument operands with the
+ // arguments of the callable.
+ auto callableLatticeIt = callableLatticeState.find(callableOp);
+ if (callableLatticeIt == callableLatticeState.end())
+ return markAllPessimisticFixpoint(op, callResults);
+
+ OperandRange callOperands = op.getArgOperands();
+ auto callableArgs = callableLatticeIt->second.getCallableArguments();
+ for (auto it : llvm::zip(callOperands, callableArgs)) {
+ BlockArgument callableArg = std::get<1>(it);
+ AbstractLatticeElement &argValue = analysis.getLatticeElement(callableArg);
+ AbstractLatticeElement &operandValue =
+ analysis.getLatticeElement(std::get<0>(it));
+ if (argValue.join(operandValue) == ChangeResult::Change)
+ visitUsers(callableArg);
+ }
+
+ // Visit the callable.
+ visitCallableOperation(callableOp);
+
+ // Merge in the lattice state for the callable results as well.
+ auto callableResults = callableLatticeIt->second.getResultLatticeElements();
+ for (auto it : llvm::zip(callResults, callableResults))
+ join(/*owner=*/op,
+ /*to=*/analysis.getLatticeElement(std::get<0>(it)),
+ /*from=*/std::get<1>(it));
+}
+
+void ForwardDataFlowSolver::visitRegionBranchOperation(
+ RegionBranchOpInterface branch,
+ ArrayRef<AbstractLatticeElement *> operandLattices) {
+ // Check to see which regions are executable.
+ SmallVector<RegionSuccessor, 1> successors;
+ analysis.getSuccessorsForOperands(branch, /*sourceIndex=*/llvm::None,
+ operandLattices, successors);
+
+ // If the interface identified that no region will be executed, mark any
+ // results of this operation as having reached the pessimistic fixpoint, as
+ // we can't reason about them.
+ // TODO: If we had an interface to detect pass through operands, we could
+ // resolve some results based on the lattice state of the operands. We could
+ // also allow for the parent operation to have itself as a region successor.
+ if (successors.empty())
+ return markAllPessimisticFixpoint(branch, branch->getResults());
+ return visitRegionSuccessors(
+ branch, successors, [&](Optional<unsigned> index) {
+ assert(index && "expected valid region index");
+ return branch.getSuccessorEntryOperands(*index);
+ });
+}
+
+void ForwardDataFlowSolver::visitRegionSuccessors(
+ Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
+ function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
+ for (const RegionSuccessor &it : regionSuccessors) {
+ Region *region = it.getSuccessor();
+ ValueRange succArgs = it.getSuccessorInputs();
+
+ // Check to see if this is the parent operation.
+ if (!region) {
+ ResultRange results = parentOp->getResults();
+ if (llvm::all_of(results, [&](Value res) { return isAtFixpoint(res); }))
+ continue;
+
+ // Mark the results outside of the input range as having reached the
+ // pessimistic fixpoint.
+ // TODO: This isn't exactly ideal. There may be situations in which a
+ // region operation can provide information for certain results that
+ // aren't part of the control flow.
+ if (succArgs.size() != results.size()) {
+ opWorklist.push_back(parentOp);
+ if (succArgs.empty())
+ return markAllPessimisticFixpoint(results);
+
+ unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
+ markAllPessimisticFixpoint(results.take_front(firstResIdx));
+ markAllPessimisticFixpoint(
+ results.drop_front(firstResIdx + succArgs.size()));
+ }
+
+ // Update the lattice for any operation results.
+ OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
+ for (auto it : llvm::zip(succArgs, operands))
+ join(parentOp, analysis.getLatticeElement(std::get<0>(it)),
+ analysis.getLatticeElement(std::get<1>(it)));
+ return;
+ }
+ assert(!region->empty() && "expected region to be non-empty");
+ Block *entryBlock = &region->front();
+ markBlockExecutable(entryBlock);
+
+ // If all of the arguments have already reached a fixpoint, the arguments
+ // have already been fully resolved.
+ Block::BlockArgListType arguments = entryBlock->getArguments();
+ if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); }))
+ continue;
+
+ // Mark any arguments that do not receive inputs as having reached a
+ // pessimistic fixpoint, as we won't be able to discern their values.
+ // TODO: This isn't exactly ideal. There may be situations in which a
+ // region operation can provide information for certain arguments that
+ // aren't part of the control flow.
+ if (succArgs.size() != arguments.size()) {
+ if (succArgs.empty()) {
+ markAllPessimisticFixpoint(arguments);
+ continue;
+ }
+
+ unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
+ markAllPessimisticFixpointAndVisitUsers(
+ arguments.take_front(firstArgIdx));
+ markAllPessimisticFixpointAndVisitUsers(
+ arguments.drop_front(firstArgIdx + succArgs.size()));
+ }
+
+ // Update the lattice of arguments that have inputs from the predecessor.
+ OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
+ for (auto it : llvm::zip(succArgs, succOperands)) {
+ AbstractLatticeElement &argValue =
+ analysis.getLatticeElement(std::get<0>(it));
+ AbstractLatticeElement &operandValue =
+ analysis.getLatticeElement(std::get<1>(it));
+ if (argValue.join(operandValue) == ChangeResult::Change)
+ visitUsers(std::get<0>(it));
+ }
+ }
+}
+
+void ForwardDataFlowSolver::visitTerminatorOperation(
+ Operation *op, ArrayRef<AbstractLatticeElement *> operandLattices) {
+ // If this operation has no successors, we treat it as an exiting terminator.
+ if (op->getNumSuccessors() == 0) {
+ Region *parentRegion = op->getParentRegion();
+ Operation *parentOp = parentRegion->getParentOp();
+
+ // Check to see if this is a terminator for a callable region.
+ if (isa<CallableOpInterface>(parentOp))
+ return visitCallableTerminatorOperation(parentOp, op, operandLattices);
+
+ // Otherwise, check to see if the parent tracks region control flow.
+ auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
+ if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
+ return;
+
+ // Query the set of successors of the current region using the current
+ // optimistic lattice state.
+ SmallVector<RegionSuccessor, 1> regionSuccessors;
+ analysis.getSuccessorsForOperands(regionInterface,
+ parentRegion->getRegionNumber(),
+ operandLattices, regionSuccessors);
+ if (regionSuccessors.empty())
+ return;
+
+ // If this terminator is not "return-like", conservatively mark all of the
+ // successor values as having reached the pessimistic fixpoint.
+ if (!op->hasTrait<OpTrait::ReturnLike>()) {
+ for (auto &it : regionSuccessors)
+ markAllPessimisticFixpointAndVisitUsers(it.getSuccessorInputs());
+ return;
+ }
+
+ // Otherwise, propagate the operand lattice states to the successors.
+ OperandRange operands = op->getOperands();
+ return visitRegionSuccessors(parentOp, regionSuccessors,
+ [&](Optional<unsigned>) { return operands; });
+ }
+
+ // Try to resolve to a specific set of successors with the current optimistic
+ // lattice state.
+ Block *block = op->getBlock();
+ if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+ SmallVector<Block *> successors;
+ if (succeeded(analysis.getSuccessorsForOperands(branch, operandLattices,
+ successors))) {
+ for (Block *succ : successors)
+ markEdgeExecutable(block, succ);
+ return;
+ }
+ }
+
+ // Otherwise, conservatively treat all edges as executable.
+ for (Block *succ : op->getSuccessors())
+ markEdgeExecutable(block, succ);
+}
+
+void ForwardDataFlowSolver::visitCallableTerminatorOperation(
+ Operation *callable, Operation *terminator,
+ ArrayRef<AbstractLatticeElement *> operandLattices) {
+ // If there are no exiting values, we have nothing to track.
+ if (terminator->getNumOperands() == 0)
+ return;
+
+ // If this callable isn't tracking any lattice state there is nothing to do.
+ auto latticeIt = callableLatticeState.find(callable);
+ if (latticeIt == callableLatticeState.end())
+ return;
+ assert(callable->getNumResults() == 0 && "expected symbol callable");
+
+ // If this terminator is not "return-like", conservatively mark all of the
+ // call-site results as having reached the pessimistic fixpoint.
+ auto callableResultLattices = latticeIt->second.getResultLatticeElements();
+ if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
+ for (auto &it : callableResultLattices)
+ it.markPessimisticFixpoint();
+ for (Operation *call : latticeIt->second.getSymbolCalls())
+ markAllPessimisticFixpoint(call, call->getResults());
+ return;
+ }
+
+ // Merge the lattice state for terminator operands into the results.
+ ChangeResult result = ChangeResult::NoChange;
+ for (auto it : llvm::zip(operandLattices, callableResultLattices))
+ result |= std::get<1>(it).join(*std::get<0>(it));
+ if (result == ChangeResult::NoChange)
+ return;
+
+ // If any of the result lattices changed, update the callers.
+ for (Operation *call : latticeIt->second.getSymbolCalls())
+ for (auto it : llvm::zip(call->getResults(), callableResultLattices))
+ join(call, analysis.getLatticeElement(std::get<0>(it)), std::get<1>(it));
+}
+
+void ForwardDataFlowSolver::visitBlock(Block *block) {
+ // If the block is not the entry block we need to compute the lattice state
+ // for the block arguments. Entry block argument lattices are computed
+ // elsewhere, such as when visiting the parent operation.
+ if (!block->isEntryBlock()) {
+ for (int i : llvm::seq<int>(0, block->getNumArguments()))
+ visitBlockArgument(block, i);
+ }
+
+ // Visit all of the operations within the block.
+ for (Operation &op : *block)
+ visitOperation(&op);
+}
+
+void ForwardDataFlowSolver::visitBlockArgument(Block *block, int i) {
+ BlockArgument arg = block->getArgument(i);
+ AbstractLatticeElement &argLattice = analysis.getLatticeElement(arg);
+ if (argLattice.isAtFixpoint())
+ return;
+
+ ChangeResult updatedLattice = ChangeResult::NoChange;
+ for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
+ Block *pred = *it;
+
+ // We only care about this predecessor if it is going to execute.
+ if (!isEdgeExecutable(pred, block))
+ continue;
+
+ // Try to get the operand forwarded by the predecessor. If we can't reason
+ // about the terminator of the predecessor, mark as having reached a
+ // fixpoint.
+ Optional<OperandRange> branchOperands;
+ if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
+ branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
+ if (!branchOperands) {
+ updatedLattice |= argLattice.markPessimisticFixpoint();
+ break;
+ }
+
+ // If the operand hasn't been resolved, it is uninitialized and can merge
+ // with anything.
+ AbstractLatticeElement *operandLattice =
+ analysis.lookupLatticeElement((*branchOperands)[i]);
+ if (!operandLattice)
+ continue;
+
+ // Otherwise, join the operand lattice into the argument lattice.
+ updatedLattice |= argLattice.join(*operandLattice);
+ if (argLattice.isAtFixpoint())
+ break;
+ }
+
+ // If the lattice changed, visit users of the argument.
+ if (updatedLattice == ChangeResult::Change)
+ visitUsers(arg);
+}
+
+ChangeResult
+ForwardDataFlowSolver::markEntryBlockExecutable(Region *region,
+ bool markPessimisticFixpoint) {
+ if (!region->empty()) {
+ if (markPessimisticFixpoint)
+ markAllPessimisticFixpoint(region->front().getArguments());
+ return markBlockExecutable(&region->front());
+ }
+ return ChangeResult::NoChange;
+}
+
+ChangeResult ForwardDataFlowSolver::markBlockExecutable(Block *block) {
+ bool marked = executableBlocks.insert(block).second;
+ if (marked)
+ blockWorklist.push_back(block);
+ return marked ? ChangeResult::Change : ChangeResult::NoChange;
+}
+
+bool ForwardDataFlowSolver::isBlockExecutable(Block *block) const {
+ return executableBlocks.count(block);
+}
+
+void ForwardDataFlowSolver::markEdgeExecutable(Block *from, Block *to) {
+ if (!executableEdges.insert(std::make_pair(from, to)).second)
+ return;
+
+ // Mark the destination as executable, and reprocess its arguments if it was
+ // already executable.
+ if (markBlockExecutable(to) == ChangeResult::NoChange) {
+ for (int i : llvm::seq<int>(0, to->getNumArguments()))
+ visitBlockArgument(to, i);
+ }
+}
+
+bool ForwardDataFlowSolver::isEdgeExecutable(Block *from, Block *to) const {
+ return executableEdges.count(std::make_pair(from, to));
+}
+
+void ForwardDataFlowSolver::markPessimisticFixpoint(Value value) {
+ analysis.getLatticeElement(value).markPessimisticFixpoint();
+}
+
+bool ForwardDataFlowSolver::isAtFixpoint(Value value) const {
+ if (auto *lattice = analysis.lookupLatticeElement(value))
+ return lattice->isAtFixpoint();
+ return false;
+}
+
+void ForwardDataFlowSolver::join(Operation *owner, AbstractLatticeElement &to,
+ const AbstractLatticeElement &from) {
+ if (to.join(from) == ChangeResult::Change)
+ opWorklist.push_back(owner);
+}
+
+//===----------------------------------------------------------------------===//
+// AbstractLatticeElement
+//===----------------------------------------------------------------------===//
+
+AbstractLatticeElement::~AbstractLatticeElement() {}
+
+//===----------------------------------------------------------------------===//
+// ForwardDataFlowAnalysisBase
+//===----------------------------------------------------------------------===//
+
+ForwardDataFlowAnalysisBase::~ForwardDataFlowAnalysisBase() {}
+
+AbstractLatticeElement &
+ForwardDataFlowAnalysisBase::getLatticeElement(Value value) {
+ AbstractLatticeElement *&latticeValue = latticeValues[value];
+ if (!latticeValue)
+ latticeValue = createLatticeElement(value);
+ return *latticeValue;
+}
+
+AbstractLatticeElement *
+ForwardDataFlowAnalysisBase::lookupLatticeElement(Value value) {
+ return latticeValues.lookup(value);
+}
+
+void ForwardDataFlowAnalysisBase::run(Operation *topLevelOp) {
+ // Run the main dataflow solver.
+ ForwardDataFlowSolver solver(*this, topLevelOp);
+ solver.solve();
+
+ // Any values that are still uninitialized now go to a pessimistic fixpoint,
+ // otherwise we assume an optimistic fixpoint has been reached.
+ for (auto &it : latticeValues)
+ if (it.second->isUninitialized())
+ it.second->markPessimisticFixpoint();
+ else
+ it.second->markOptimisticFixpoint();
+}
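The solver above marks blocks and CFG edges executable only when a terminator can actually reach them. It joins block-argument lattices only over executable predecessor edges (`visitBlockArgument`, `markEdgeExecutable`), so a branch on a value that is still constant keeps the untaken successor dead. The standalone C++ sketch below illustrates that optimistic reachability on a hypothetical toy CFG. It does not use MLIR, and `ToyForwardSolver`, `Block`, `Edge`, `Operand`, and `Lattice` are illustrative names only. For brevity it folds the argument join into edge marking and uses a single block worklist instead of the separate operation and block worklists of `ForwardDataFlowSolver`.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

// Three-level lattice: uninitialized -> constant -> overdefined.
// `join` returns true when the lattice changed.
struct Lattice {
  enum Kind { Uninitialized, Constant, Overdefined };
  Kind kind = Uninitialized;
  long value = 0;

  static Lattice constant(long c) { return Lattice{Constant, c}; }

  bool join(const Lattice &rhs) {
    if (kind == Overdefined || rhs.kind == Uninitialized)
      return false;
    if (kind == Uninitialized) {
      *this = rhs;
      return true;
    }
    if (rhs.kind == Constant && rhs.value == value)
      return false;
    kind = Overdefined;
    return true;
  }
};

// Hypothetical toy CFG: every block has one argument, and each edge forwards
// either a constant or the source block's own argument.
struct Operand { bool isBlockArg; long constant; };
struct Edge { std::size_t target; Operand forwarded; };
struct Block {
  // If true, the terminator branches on the block argument: edges[0] is taken
  // when it is non-zero and edges[1] when it is zero.
  bool condOnArg = false;
  std::vector<Edge> edges;
};

class ToyForwardSolver {
public:
  explicit ToyForwardSolver(std::vector<Block> cfg)
      : blocks(std::move(cfg)), args(blocks.size()) {}

  void solve(std::size_t entry) {
    // The entry argument comes from outside: pessimistic fixpoint.
    args[entry].kind = Lattice::Overdefined;
    executable.insert(entry);
    worklist.push_back(entry);
    while (!worklist.empty()) {
      std::size_t b = worklist.back();
      worklist.pop_back();
      visitTerminator(b);
    }
  }

  bool isExecutable(std::size_t b) const { return executable.count(b) != 0; }
  bool isEdgeExecutable(std::size_t from, std::size_t to) const {
    return executableEdges.count({from, to}) != 0;
  }
  const Lattice &argument(std::size_t b) const {
    assert(b < args.size());
    return args[b];
  }

private:
  void visitTerminator(std::size_t b) {
    const Block &block = blocks[b];
    const Lattice cond = args[b]; // copy: a self-loop may update args[b]
    for (std::size_t i = 0; i < block.edges.size(); ++i) {
      if (block.condOnArg) {
        if (cond.kind == Lattice::Uninitialized)
          return; // wait until the condition has a lattice value
        bool taken = (cond.value != 0) == (i == 0);
        if (cond.kind == Lattice::Constant && !taken)
          continue; // a constant condition leaves this edge dead
      }
      markEdgeExecutable(b, block.edges[i]);
    }
  }

  void markEdgeExecutable(std::size_t from, const Edge &edge) {
    executableEdges.insert({from, edge.target});
    Lattice incoming = edge.forwarded.isBlockArg
                           ? args[from]
                           : Lattice::constant(edge.forwarded.constant);
    // Revisit the target if it just became live or its argument changed.
    bool changed = args[edge.target].join(incoming);
    bool newlyExecutable = executable.insert(edge.target).second;
    if (changed || newlyExecutable)
      worklist.push_back(edge.target);
  }

  std::vector<Block> blocks;
  std::vector<Lattice> args;
  std::set<std::size_t> executable;
  std::set<std::pair<std::size_t, std::size_t>> executableEdges;
  std::vector<std::size_t> worklist;
};
```

Suppose block 0 forwards the constant 0 to block 1, and block 1 branches on its argument. Then only the "zero" successor of block 1 becomes executable. If the condition were overdefined, both successors would be visited and any conflicting constants would join to overdefined.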
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index a09403510da24..3d725d77db8ec 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -15,6 +15,7 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
+#include "mlir/Analysis/DataFlowAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -25,326 +26,173 @@
using namespace mlir;
+//===----------------------------------------------------------------------===//
+// SCCP Analysis
+//===----------------------------------------------------------------------===//
+
namespace {
-/// This class represents a single lattice value. A lattive value corresponds to
-/// the various different states that a value in the SCCP dataflow analysis can
-/// take. See 'Kind' below for more details on the different states a value can
-/// take.
-class LatticeValue {
- enum Kind {
- /// A value with a yet to be determined value. This state may be changed to
- /// anything.
- Unknown,
-
- /// A value that is known to be a constant. This state may be changed to
- /// overdefined.
- Constant,
-
- /// A value that cannot statically be determined to be a constant. This
- /// state cannot be changed.
- Overdefined
- };
+struct SCCPLatticeValue {
+ SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr)
+ : constant(constant), constantDialect(dialect) {}
-public:
- /// Initialize a lattice value with "Unknown".
- LatticeValue()
- : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {}
- /// Initialize a lattice value with a constant.
- LatticeValue(Attribute attr, Dialect *dialect)
- : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {}
-
- /// Returns true if this lattice value is unknown.
- bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; }
-
- /// Mark the lattice value as overdefined.
- void markOverdefined() {
- constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
- constantDialect = nullptr;
+ /// The pessimistic state of SCCP is non-constant.
+ static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) {
+ return SCCPLatticeValue();
}
-
- /// Returns true if the lattice is overdefined.
- bool isOverdefined() const {
- return constantAndTag.getInt() == Kind::Overdefined;
+ static SCCPLatticeValue getPessimisticValueState(Value value) {
+ return SCCPLatticeValue();
}
- /// Mark the lattice value as constant.
- void markConstant(Attribute value, Dialect *dialect) {
- constantAndTag.setPointerAndInt(value, Kind::Constant);
- constantDialect = dialect;
+ /// Equivalence for SCCP only accounts for the constant, not the originating
+ /// dialect.
+ bool operator==(const SCCPLatticeValue &rhs) const {
+ return constant == rhs.constant;
}
- /// If this lattice is constant, return the constant. Returns nullptr
- /// otherwise.
- Attribute getConstant() const { return constantAndTag.getPointer(); }
-
- /// If this lattice is constant, return the dialect to use when materializing
- /// the constant.
- Dialect *getConstantDialect() const {
- assert(getConstant() && "expected valid constant");
- return constantDialect;
- }
-
- /// Merge in the value of the 'rhs' lattice into this one. Returns true if the
- /// lattice value changed.
- bool meet(const LatticeValue &rhs) {
- // If we are already overdefined, or rhs is unknown, there is nothing to do.
- if (isOverdefined() || rhs.isUnknown())
- return false;
- // If we are unknown, just take the value of rhs.
- if (isUnknown()) {
- constantAndTag = rhs.constantAndTag;
- constantDialect = rhs.constantDialect;
- return true;
- }
-
- // Otherwise, if this value doesn't match rhs go straight to overdefined.
- if (constantAndTag != rhs.constantAndTag) {
- markOverdefined();
- return true;
- }
- return false;
+ /// To join the state of two values, we simply check for equivalence.
+ static SCCPLatticeValue join(const SCCPLatticeValue &lhs,
+ const SCCPLatticeValue &rhs) {
+ return lhs == rhs ? lhs : SCCPLatticeValue();
}
-private:
- /// The attribute value if this is a constant and the tag for the element
- /// kind.
- llvm::PointerIntPair<Attribute, 2, Kind> constantAndTag;
+ /// The constant attribute value.
+ Attribute constant;
- /// The dialect the constant originated from. This is only valid if the
- /// lattice is a constant. This is not used as part of the key, and is only
- /// needed to materialize the held constant if necessary.
+ /// The dialect the constant originated from. This is not used as part of the
+ /// key, and is only needed to materialize the held constant if necessary.
Dialect *constantDialect;
};
-/// This class contains various state used when computing the lattice of a
-/// callable operation.
-class CallableLatticeState {
-public:
- /// Build a lattice state with a given callable region, and a specified number
- /// of results to be initialized to the default lattice value (Unknown).
- CallableLatticeState(Region *callableRegion, unsigned numResults)
- : callableArguments(callableRegion->getArguments()),
- resultLatticeValues(numResults) {}
-
- /// Returns the arguments to the callable region.
- Block::BlockArgListType getCallableArguments() const {
- return callableArguments;
- }
-
- /// Returns the lattice value for the results of the callable region.
- MutableArrayRef<LatticeValue> getResultLatticeValues() {
- return resultLatticeValues;
- }
-
- /// Add a call to this callable. This is only used if the callable defines a
- /// symbol.
- void addSymbolCall(Operation *op) { symbolCalls.push_back(op); }
-
- /// Return the calls that reference this callable. This is only used
- /// if the callable defines a symbol.
- ArrayRef<Operation *> getSymbolCalls() const { return symbolCalls; }
-
-private:
- /// The arguments of the callable region.
- Block::BlockArgListType callableArguments;
+struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
+ using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis;
+ ~SCCPAnalysis() override = default;
- /// The lattice state for each of the results of this region. The return
- /// values of the callable aren't SSA values, so we need to track them
- /// separately.
- SmallVector<LatticeValue, 4> resultLatticeValues;
-
- /// The calls referencing this callable if this callable defines a symbol.
- /// This removes the need to recompute symbol references during propagation.
- /// Value based references are trivial to resolve, so they can be done
- /// in-place.
- SmallVector<Operation *, 4> symbolCalls;
-};
-
-/// This class represents the solver for the SCCP analysis. This class acts as
-/// the propagation engine for computing which values form constants.
-class SCCPSolver {
-public:
- /// Initialize the solver with the given top-level operation.
- SCCPSolver(Operation *op);
-
- /// Run the solver until it converges.
- void solve();
-
- /// Rewrite the given regions using the computing analysis. This replaces the
- /// uses of all values that have been computed to be constant, and erases as
- /// many newly dead operations.
- void rewrite(MLIRContext *context, MutableArrayRef<Region> regions);
-
-private:
- /// Initialize the set of symbol defining callables that can have their
- /// arguments and results tracked. 'op' is the top-level operation that SCCP
- /// is operating on.
- void initializeSymbolCallables(Operation *op);
-
- /// Replace the given value with a constant if the corresponding lattice
- /// represents a constant. Returns success if the value was replaced, failure
- /// otherwise.
- LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder,
- Value value);
-
- /// Visit the users of the given IR that reside within executable blocks.
- template <typename T>
- void visitUsers(T &value) {
- for (Operation *user : value.getUsers())
- if (isBlockExecutable(user->getBlock()))
- visitOperation(user);
- }
+ ChangeResult
+ visitOperation(Operation *op,
+ ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final {
+ // Don't try to simulate the results of a region operation as we can't
+ // guarantee that folding will be out-of-place. We don't allow in-place
+ // folds as the desire here is for simulated execution, and not general
+ // folding.
+ if (op->getNumRegions())
+ return markAllPessimisticFixpoint(op->getResults());
+
+ SmallVector<Attribute> constantOperands(
+ llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
+ return value->getValue().constant;
+ }));
+
+ // Save the original operands and attributes just in case the operation
+ // folds in-place. The constant passed in may not correspond to the real
+ // runtime value, so in-place updates are not allowed.
+ SmallVector<Value, 8> originalOperands(op->getOperands());
+ DictionaryAttr originalAttrs = op->getAttrDictionary();
+
+ // Simulate the result of folding this operation to a constant. If folding
+ // fails or was an in-place fold, mark the results as overdefined.
+ SmallVector<OpFoldResult, 8> foldResults;
+ foldResults.reserve(op->getNumResults());
+ if (failed(op->fold(constantOperands, foldResults)))
+ return markAllPessimisticFixpoint(op->getResults());
+
+ // If the folding was in-place, mark the results as overdefined and reset
+ // the operation. We don't allow in-place folds as the desire here is for
+ // simulated execution, and not general folding.
+ if (foldResults.empty()) {
+ op->setOperands(originalOperands);
+ op->setAttrs(originalAttrs);
+ return markAllPessimisticFixpoint(op->getResults());
+ }
- /// Visit the given operation and compute any necessary lattice state.
- void visitOperation(Operation *op);
-
- /// Visit the given call operation and compute any necessary lattice state.
- void visitCallOperation(CallOpInterface op);
-
- /// Visit the given callable operation and compute any necessary lattice
- /// state.
- void visitCallableOperation(Operation *op);
-
- /// Visit the given operation, which defines regions, and compute any
- /// necessary lattice state. This also resolves the lattice state of both the
- /// operation results and any nested regions.
- void visitRegionOperation(Operation *op,
- ArrayRef<Attribute> constantOperands);
-
- /// Visit the given set of region successors, computing any necessary lattice
- /// state. The provided function returns the input operands to the region at
- /// the given index. If the index is 'None', the input operands correspond to
- /// the parent operation results.
- void visitRegionSuccessors(
- Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
- function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
-
- /// Visit the given terminator operation and compute any necessary lattice
- /// state.
- void visitTerminatorOperation(Operation *op,
- ArrayRef<Attribute> constantOperands);
-
- /// Visit the given terminator operation that exits a callable region. These
- /// are terminators with no CFG successors.
- void visitCallableTerminatorOperation(Operation *callable,
- Operation *terminator);
-
- /// Visit the given block and compute any necessary lattice state.
- void visitBlock(Block *block);
-
- /// Visit argument #'i' of the given block and compute any necessary lattice
- /// state.
- void visitBlockArgument(Block *block, int i);
-
- /// Mark the entry block of the given region as executable. Returns false if
- /// the block was already marked executable. If `markArgsOverdefined` is true,
- /// the arguments of the entry block are also set to overdefined.
- bool markEntryBlockExecutable(Region *region, bool markArgsOverdefined);
-
- /// Mark the given block as executable. Returns false if the block was already
- /// marked executable.
- bool markBlockExecutable(Block *block);
-
- /// Returns true if the given block is executable.
- bool isBlockExecutable(Block *block) const;
-
- /// Mark the edge between 'from' and 'to' as executable.
- void markEdgeExecutable(Block *from, Block *to);
-
- /// Return true if the edge between 'from' and 'to' is executable.
- bool isEdgeExecutable(Block *from, Block *to) const;
-
- /// Mark the given value as overdefined. This means that we cannot refine a
- /// specific constant for this value.
- void markOverdefined(Value value);
-
- /// Mark all of the given values as overdefined.
- template <typename ValuesT>
- void markAllOverdefined(ValuesT values) {
- for (auto value : values)
- markOverdefined(value);
- }
- template <typename ValuesT>
- void markAllOverdefined(Operation *op, ValuesT values) {
- markAllOverdefined(values);
- opWorklist.push_back(op);
- }
- template <typename ValuesT>
- void markAllOverdefinedAndVisitUsers(ValuesT values) {
- for (auto value : values) {
- auto &lattice = latticeValues[value];
- if (!lattice.isOverdefined()) {
- lattice.markOverdefined();
- visitUsers(value);
- }
+ // Merge the fold results into the lattice for this operation.
+ assert(foldResults.size() == op->getNumResults() && "invalid result size");
+ Dialect *dialect = op->getDialect();
+ ChangeResult result = ChangeResult::NoChange;
+ for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
+ LatticeElement<SCCPLatticeValue> &lattice =
+ getLatticeElement(op->getResult(i));
+
+ // Merge in the result of the fold, either a constant or a value.
+ OpFoldResult foldResult = foldResults[i];
+ if (Attribute attr = foldResult.dyn_cast<Attribute>())
+ result |= lattice.join(SCCPLatticeValue(attr, dialect));
+ else
+ result |= lattice.join(getLatticeElement(foldResult.get<Value>()));
+ }
+ return result;
+ }
+
+ /// Implementation of `getSuccessorsForOperands` that uses constant operands
+ /// to potentially remove dead successors.
+ LogicalResult getSuccessorsForOperands(
+ BranchOpInterface branch,
+ ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
+ SmallVectorImpl<Block *> &successors) final {
+ SmallVector<Attribute> constantOperands(
+ llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
+ return value->getValue().constant;
+ }));
+ if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
+ successors.push_back(singleSucc);
+ return success();
}
+ return failure();
}
- /// Returns true if the given value was marked as overdefined.
- bool isOverdefined(Value value) const;
-
- /// Merge in the given lattice 'from' into the lattice 'to'. 'owner'
- /// corresponds to the parent operation of 'to'.
- void meet(Operation *owner, LatticeValue &to, const LatticeValue &from);
-
- /// The lattice for each SSA value.
- DenseMap<Value, LatticeValue> latticeValues;
-
- /// The set of blocks that are known to execute, or are intrinsically live.
- SmallPtrSet<Block *, 16> executableBlocks;
-
- /// The set of control flow edges that are known to execute.
- DenseSet<std::pair<Block *, Block *>> executableEdges;
-
- /// A worklist containing blocks that need to be processed.
- SmallVector<Block *, 64> blockWorklist;
-
- /// A worklist of operations that need to be processed.
- SmallVector<Operation *, 64> opWorklist;
-
- /// The callable operations that have their argument/result state tracked.
- DenseMap<Operation *, CallableLatticeState> callableLatticeState;
-
- /// A map between a call operation and the resolved symbol callable. This
- /// avoids re-resolving symbol references during propagation. Value based
- /// callables are trivial to resolve, so they can be done in-place.
- DenseMap<Operation *, Operation *> callToSymbolCallable;
-
- /// A symbol table used for O(1) symbol lookups during simplification.
- SymbolTableCollection symbolTable;
+ /// Implementation of `getSuccessorsForOperands` that uses constant operands
+ /// to potentially remove dead region successors.
+ void getSuccessorsForOperands(
+ RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
+ ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
+ SmallVectorImpl<RegionSuccessor> &successors) final {
+ SmallVector<Attribute> constantOperands(
+ llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
+ return value->getValue().constant;
+ }));
+ branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
+ }
};
-} // end anonymous namespace
+} // namespace
-SCCPSolver::SCCPSolver(Operation *op) {
- /// Initialize the solver with the regions within this operation.
- for (Region &region : op->getRegions()) {
- // Mark the entry block as executable. The values passed to these regions
- // are also invisible, so mark any arguments as overdefined.
- markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
- }
- initializeSymbolCallables(op);
-}
+//===----------------------------------------------------------------------===//
+// SCCP Rewrites
+//===----------------------------------------------------------------------===//
-void SCCPSolver::solve() {
- while (!blockWorklist.empty() || !opWorklist.empty()) {
- // Process any operations in the op worklist.
- while (!opWorklist.empty())
- visitUsers(*opWorklist.pop_back_val());
+/// Replace the given value with a constant if the corresponding lattice
+/// represents a constant. Returns success if the value was replaced, failure
+/// otherwise.
+static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
+ OpBuilder &builder,
+ OperationFolder &folder, Value value) {
+ LatticeElement<SCCPLatticeValue> *lattice =
+ analysis.lookupLatticeElement(value);
+ if (!lattice)
+ return failure();
+ SCCPLatticeValue &latticeValue = lattice->getValue();
+ if (!latticeValue.constant)
+ return failure();
- // Process any blocks in the block worklist.
- while (!blockWorklist.empty())
- visitBlock(blockWorklist.pop_back_val());
- }
+ // Attempt to materialize a constant for the given value.
+ Dialect *dialect = latticeValue.constantDialect;
+ Value constant = folder.getOrCreateConstant(
+ builder, dialect, latticeValue.constant, value.getType(), value.getLoc());
+ if (!constant)
+ return failure();
+
+ value.replaceAllUsesWith(constant);
+ return success();
}
-void SCCPSolver::rewrite(MLIRContext *context,
- MutableArrayRef<Region> initialRegions) {
- SmallVector<Block *, 8> worklist;
+/// Rewrite the given regions using the computed analysis. This replaces the
+/// uses of all values that have been computed to be constant, and erases as
+/// many newly dead operations as possible.
+static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
+ MutableArrayRef<Region> initialRegions) {
+ SmallVector<Block *> worklist;
auto addToWorklist = [&](MutableArrayRef<Region> regions) {
 for (Region &region : regions)
- for (Block &block : region)
- if (isBlockExecutable(&block))
- worklist.push_back(&block);
+ for (Block &block : llvm::reverse(region))
+ worklist.push_back(&block);
};
// An operation folder used to create and unique constants.
@@ -355,18 +203,14 @@ void SCCPSolver::rewrite(MLIRContext *context,
while (!worklist.empty()) {
Block *block = worklist.pop_back_val();
- // Replace any block arguments with constants.
- builder.setInsertionPointToStart(block);
- for (BlockArgument arg : block->getArguments())
- (void)replaceWithConstant(builder, folder, arg);
-
for (Operation &op : llvm::make_early_inc_range(*block)) {
builder.setInsertionPoint(&op);
// Replace any result with constants.
bool replacedAll = op.getNumResults() != 0;
for (Value res : op.getResults())
- replacedAll &= succeeded(replaceWithConstant(builder, folder, res));
+ replacedAll &=
+ succeeded(replaceWithConstant(analysis, builder, folder, res));
// If all of the results of the operation were replaced, try to erase
// the operation completely.
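The rewrite step replaces every result whose lattice holds a constant, then tries to erase the now-unused operation. That step can also be sketched on a toy IR. This is a C++17 illustration under stated assumptions. `ToyOp` and `rewriteWithConstants` are hypothetical names. A `std::map` stands in for the constant uniquing that `OperationFolder` performs. "Has no side effects" is a simplified stand-in for the deadness check in the lines elided from this hunk; it is not that check itself.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical single-block toy IR; op i defines value i.
struct ToyOp {
  std::string kind;            // "const", "add", "call", ...
  long imm = 0;                // payload for "const"
  std::vector<int> operands;   // values used by this op
  bool hasSideEffects = false; // side-effecting ops are never erased
  bool erased = false;
};

// Redirect all uses of each value proven constant to a uniqued constant op,
// then erase the defining op if (by the simplified rule assumed here) it has
// no side effects. Returns the number of erased ops.
int rewriteWithConstants(std::vector<ToyOp> &ops,
                         const std::map<int, long> &provenConstants) {
  std::map<long, int> constantPool; // constant value -> materialized op index
  int numErased = 0;
  const int numOriginal = static_cast<int>(ops.size());
  for (int i = 0; i < numOriginal; ++i) {
    auto it = provenConstants.find(i);
    if (it == provenConstants.end())
      continue;

    // Materialize the constant once and reuse it afterwards. New constants
    // are appended at the end; the toy does not check def-before-use order.
    auto [poolIt, inserted] = constantPool.try_emplace(it->second, 0);
    if (inserted) {
      ops.push_back({"const", it->second, {}, false, false});
      poolIt->second = static_cast<int>(ops.size()) - 1;
    }

    // Replace all uses of value i with the materialized constant.
    for (ToyOp &user : ops)
      for (int &operand : user.operands)
        if (operand == i)
          operand = poolIt->second;

    // Simplified deadness rule (assumption): erase only side-effect-free ops.
    if (!ops[i].hasSideEffects) {
      ops[i].erased = true;
      ++numErased;
    }
  }
  return numErased;
}
```

In this sketch, an op with side effects keeps executing even when its result is proven constant. Its users still see the constant. The same reasoning explains why "all results replaced" alone is not enough to drop an operation.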
@@ -379,532 +223,14 @@ void SCCPSolver::rewrite(MLIRContext *context,
// Add any the regions of this operation to the worklist.
addToWorklist(op.getRegions());
}
- }
-}
-
-void SCCPSolver::initializeSymbolCallables(Operation *op) {
- // Initialize the set of symbol callables that can have their state tracked.
- // This tracks which symbol callable operations we can propagate within and
- // out of.
- auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
- Region &symbolTableRegion = symTable->getRegion(0);
- Block *symbolTableBlock = &symbolTableRegion.front();
- for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) {
- // We won't be able to track external callables.
- Region *callableRegion = callable.getCallableRegion();
- if (!callableRegion)
- continue;
- // We only care about symbol defining callables here.
- auto symbol = dyn_cast<SymbolOpInterface>(callable.getOperation());
- if (!symbol)
- continue;
- callableLatticeState.try_emplace(callable, callableRegion,
- callable.getCallableResults().size());
-
- // If not all of the uses of this symbol are visible, we can't track the
- // state of the arguments.
- if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) {
- for (Region &region : callable->getRegions())
- markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
- }
- }
- if (callableLatticeState.empty())
- return;
-
- // After computing the valid callables, walk any symbol uses to check
- // for non-call references. We won't be able to track the lattice state
- // for arguments to these callables, as we can't guarantee that we can see
- // all of its calls.
- Optional<SymbolTable::UseRange> uses =
- SymbolTable::getSymbolUses(&symbolTableRegion);
- if (!uses) {
- // If we couldn't gather the symbol uses, conservatively assume that
- // we can't track information for any nested symbols.
- op->walk([&](CallableOpInterface op) { callableLatticeState.erase(op); });
- return;
- }
-
- for (const SymbolTable::SymbolUse &use : *uses) {
- // If the use is a call, track it to avoid the need to recompute the
- // reference later.
- if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
- Operation *symCallable = callOp.resolveCallable(&symbolTable);
- auto callableLatticeIt = callableLatticeState.find(symCallable);
- if (callableLatticeIt != callableLatticeState.end()) {
- callToSymbolCallable.try_emplace(callOp, symCallable);
-
- // We only need to record the call in the lattice if it produces any
- // values.
- if (callOp->getNumResults())
- callableLatticeIt->second.addSymbolCall(callOp);
- }
- continue;
- }
- // This use isn't a call, so don't we know all of the callers.
- auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
- auto it = callableLatticeState.find(symbol);
- if (it != callableLatticeState.end()) {
- for (Region &region : it->first->getRegions())
- markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
- }
- }
- };
- SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(),
- walkFn);
-}
-
-LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder,
- OperationFolder &folder,
- Value value) {
- auto it = latticeValues.find(value);
- auto attr = it == latticeValues.end() ? nullptr : it->second.getConstant();
- if (!attr)
- return failure();
-
- // Attempt to materialize a constant for the given value.
- Dialect *dialect = it->second.getConstantDialect();
- Value constant = folder.getOrCreateConstant(builder, dialect, attr,
- value.getType(), value.getLoc());
- if (!constant)
- return failure();
-
- value.replaceAllUsesWith(constant);
- latticeValues.erase(it);
- return success();
-}
-void SCCPSolver::visitOperation(Operation *op) {
- // Collect all of the constant operands feeding into this operation. If any
- // are not ready to be resolved, bail out and wait for them to resolve.
- SmallVector<Attribute, 8> operandConstants;
- operandConstants.reserve(op->getNumOperands());
- for (Value operand : op->getOperands()) {
- // Make sure all of the operands are resolved first.
- auto &operandLattice = latticeValues[operand];
- if (operandLattice.isUnknown())
- return;
- operandConstants.push_back(operandLattice.getConstant());
- }
-
- // If this is a terminator operation, process any control flow lattice state.
- if (op->hasTrait<OpTrait::IsTerminator>())
- visitTerminatorOperation(op, operandConstants);
-
- // Process call operations. The call visitor processes result values, so we
- // can exit afterwards.
- if (CallOpInterface call = dyn_cast<CallOpInterface>(op))
- return visitCallOperation(call);
-
- // Process callable operations. These are specially handled region operations
- // that track dataflow via calls.
- if (isa<CallableOpInterface>(op)) {
- // If this callable has a tracked lattice state, it will be visited by calls
- // that reference it instead. This way, we don't assume that it is
- // executable unless there is a proper reference to it.
- if (callableLatticeState.count(op))
- return;
- return visitCallableOperation(op);
- }
-
- // Process region holding operations. The region visitor processes result
- // values, so we can exit afterwards.
- if (op->getNumRegions())
- return visitRegionOperation(op, operandConstants);
-
- // If this op produces no results, it can't produce any constants.
- if (op->getNumResults() == 0)
- return;
-
- // If all of the results of this operation are already overdefined, bail out
- // early.
- auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); };
- if (llvm::all_of(op->getResults(), isOverdefinedFn))
- return;
-
- // Save the original operands and attributes just in case the operation folds
- // in-place. The constant passed in may not correspond to the real runtime
- // value, so in-place updates are not allowed.
- SmallVector<Value, 8> originalOperands(op->getOperands());
- DictionaryAttr originalAttrs = op->getAttrDictionary();
-
- // Simulate the result of folding this operation to a constant. If folding
- // fails or was an in-place fold, mark the results as overdefined.
- SmallVector<OpFoldResult, 8> foldResults;
- foldResults.reserve(op->getNumResults());
- if (failed(op->fold(operandConstants, foldResults)))
- return markAllOverdefined(op, op->getResults());
-
- // If the folding was in-place, mark the results as overdefined and reset the
- // operation. We don't allow in-place folds as the desire here is for
- // simulated execution, and not general folding.
- if (foldResults.empty()) {
- op->setOperands(originalOperands);
- op->setAttrs(originalAttrs);
- return markAllOverdefined(op, op->getResults());
- }
-
- // Merge the fold results into the lattice for this operation.
- assert(foldResults.size() == op->getNumResults() && "invalid result size");
- Dialect *opDialect = op->getDialect();
- for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
- LatticeValue &resultLattice = latticeValues[op->getResult(i)];
-
- // Merge in the result of the fold, either a constant or a value.
- OpFoldResult foldResult = foldResults[i];
- if (Attribute foldAttr = foldResult.dyn_cast<Attribute>())
- meet(op, resultLattice, LatticeValue(foldAttr, opDialect));
- else
- meet(op, resultLattice, latticeValues[foldResult.get<Value>()]);
- }
-}
-
-void SCCPSolver::visitCallableOperation(Operation *op) {
- // Mark the regions as executable. If we aren't tracking lattice state for
- // this callable, mark all of the region arguments as overdefined.
- bool isTrackingLatticeState = callableLatticeState.count(op);
- for (Region &region : op->getRegions())
- markEntryBlockExecutable(&region, !isTrackingLatticeState);
-
- // TODO: Add support for non-symbol callables when necessary. If the callable
- // has non-call uses we would mark overdefined, otherwise allow for
- // propagating the return values out.
- markAllOverdefined(op, op->getResults());
-}
-
-void SCCPSolver::visitCallOperation(CallOpInterface op) {
- ResultRange callResults = op->getResults();
-
- // Resolve the callable operation for this call.
- Operation *callableOp = nullptr;
- if (Value callableValue = op.getCallableForCallee().dyn_cast<Value>())
- callableOp = callableValue.getDefiningOp();
- else
- callableOp = callToSymbolCallable.lookup(op);
-
- // The callable of this call can't be resolved, mark any results overdefined.
- if (!callableOp)
- return markAllOverdefined(op, callResults);
-
- // If this callable is tracking state, merge the argument operands with the
- // arguments of the callable.
- auto callableLatticeIt = callableLatticeState.find(callableOp);
- if (callableLatticeIt == callableLatticeState.end())
- return markAllOverdefined(op, callResults);
-
- OperandRange callOperands = op.getArgOperands();
- auto callableArgs = callableLatticeIt->second.getCallableArguments();
- for (auto it : llvm::zip(callOperands, callableArgs)) {
- BlockArgument callableArg = std::get<1>(it);
- if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)]))
- visitUsers(callableArg);
- }
-
- // Visit the callable.
- visitCallableOperation(callableOp);
-
- // Merge in the lattice state for the callable results as well.
- auto callableResults = callableLatticeIt->second.getResultLatticeValues();
- for (auto it : llvm::zip(callResults, callableResults))
- meet(/*owner=*/op, /*to=*/latticeValues[std::get<0>(it)],
- /*from=*/std::get<1>(it));
-}
-
-void SCCPSolver::visitRegionOperation(Operation *op,
- ArrayRef<Attribute> constantOperands) {
- // Check to see if we can reason about the internal control flow of this
- // region operation.
- auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
- if (!regionInterface) {
- // If we can't, conservatively mark all regions as executable.
- for (Region &region : op->getRegions())
- markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true);
-
- // Don't try to simulate the results of a region operation as we can't
- // guarantee that folding will be out-of-place. We don't allow in-place
- // folds as the desire here is for simulated execution, and not general
- // folding.
- return markAllOverdefined(op, op->getResults());
- }
-
- // Check to see which regions are executable.
- SmallVector<RegionSuccessor, 1> successors;
- regionInterface.getSuccessorRegions(/*index=*/llvm::None, constantOperands,
- successors);
-
- // If the interface identified that no region will be executed. Mark
- // any results of this operation as overdefined, as we can't reason about
- // them.
- // TODO: If we had an interface to detect pass through operands, we could
- // resolve some results based on the lattice state of the operands. We could
- // also allow for the parent operation to have itself as a region successor.
- if (successors.empty())
- return markAllOverdefined(op, op->getResults());
- return visitRegionSuccessors(op, successors, [&](Optional<unsigned> index) {
- assert(index && "expected valid region index");
- return regionInterface.getSuccessorEntryOperands(*index);
- });
-}
-
-void SCCPSolver::visitRegionSuccessors(
- Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
- function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
- for (const RegionSuccessor &it : regionSuccessors) {
- Region *region = it.getSuccessor();
- ValueRange succArgs = it.getSuccessorInputs();
-
- // Check to see if this is the parent operation.
- if (!region) {
- ResultRange results = parentOp->getResults();
- if (llvm::all_of(results, [&](Value res) { return isOverdefined(res); }))
- continue;
-
- // Mark the results outside of the input range as overdefined.
- if (succArgs.size() != results.size()) {
- opWorklist.push_back(parentOp);
- if (succArgs.empty())
- return markAllOverdefined(results);
-
- unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
- markAllOverdefined(results.take_front(firstResIdx));
- markAllOverdefined(results.drop_front(firstResIdx + succArgs.size()));
- }
-
- // Update the lattice for any operation results.
- OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
- for (auto it : llvm::zip(succArgs, operands))
- meet(parentOp, latticeValues[std::get<0>(it)],
- latticeValues[std::get<1>(it)]);
- return;
- }
- assert(!region->empty() && "expected region to be non-empty");
- Block *entryBlock = &region->front();
- markBlockExecutable(entryBlock);
-
- // If all of the arguments are already overdefined, the arguments have
- // already been fully resolved.
- auto arguments = entryBlock->getArguments();
- if (llvm::all_of(arguments, [&](Value arg) { return isOverdefined(arg); }))
- continue;
-
- // Mark any arguments that do not receive inputs as overdefined, we won't be
- // able to discern if they are constant.
- if (succArgs.size() != arguments.size()) {
- if (succArgs.empty()) {
- markAllOverdefined(arguments);
- continue;
- }
-
- unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
- markAllOverdefinedAndVisitUsers(arguments.take_front(firstArgIdx));
- markAllOverdefinedAndVisitUsers(
- arguments.drop_front(firstArgIdx + succArgs.size()));
- }
-
- // Update the lattice for arguments that have inputs from the predecessor.
- OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
- for (auto it : llvm::zip(succArgs, succOperands)) {
- LatticeValue &argLattice = latticeValues[std::get<0>(it)];
- if (argLattice.meet(latticeValues[std::get<1>(it)]))
- visitUsers(std::get<0>(it));
- }
- }
-}
-
-void SCCPSolver::visitTerminatorOperation(
- Operation *op, ArrayRef<Attribute> constantOperands) {
- // If this operation has no successors, we treat it as an exiting terminator.
- if (op->getNumSuccessors() == 0) {
- Region *parentRegion = op->getParentRegion();
- Operation *parentOp = parentRegion->getParentOp();
-
- // Check to see if this is a terminator for a callable region.
- if (isa<CallableOpInterface>(parentOp))
- return visitCallableTerminatorOperation(parentOp, op);
-
- // Otherwise, check to see if the parent tracks region control flow.
- auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
- if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
- return;
-
- // Query the set of successors from the current region.
- SmallVector<RegionSuccessor, 1> regionSuccessors;
- regionInterface.getSuccessorRegions(parentRegion->getRegionNumber(),
- constantOperands, regionSuccessors);
- if (regionSuccessors.empty())
- return;
-
- // If this terminator is not "return-like", conservatively mark all of the
- // successor values as overdefined.
- if (!op->hasTrait<OpTrait::ReturnLike>()) {
- for (auto &it : regionSuccessors)
- markAllOverdefinedAndVisitUsers(it.getSuccessorInputs());
- return;
- }
-
- // Otherwise, propagate the operand lattice states to each of the
- // successors.
- OperandRange operands = op->getOperands();
- return visitRegionSuccessors(parentOp, regionSuccessors,
- [&](Optional<unsigned>) { return operands; });
- }
-
- // Try to resolve to a specific successor with the constant operands.
- if (auto branch = dyn_cast<BranchOpInterface>(op)) {
- if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
- markEdgeExecutable(op->getBlock(), singleSucc);
- return;
- }
- }
-
- // Otherwise, conservatively treat all edges as executable.
- Block *block = op->getBlock();
- for (Block *succ : op->getSuccessors())
- markEdgeExecutable(block, succ);
-}
-
-void SCCPSolver::visitCallableTerminatorOperation(Operation *callable,
- Operation *terminator) {
- // If there are no exiting values, we have nothing to track.
- if (terminator->getNumOperands() == 0)
- return;
-
- // If this callable isn't tracking any lattice state there is nothing to do.
- auto latticeIt = callableLatticeState.find(callable);
- if (latticeIt == callableLatticeState.end())
- return;
- assert(callable->getNumResults() == 0 && "expected symbol callable");
-
- // If this terminator is not "return-like", conservatively mark all of the
- // call-site results as overdefined.
- auto callableResultLattices = latticeIt->second.getResultLatticeValues();
- if (!terminator->hasTrait<OpTrait::ReturnLike>()) {
- for (auto &it : callableResultLattices)
- it.markOverdefined();
- for (Operation *call : latticeIt->second.getSymbolCalls())
- markAllOverdefined(call, call->getResults());
- return;
- }
-
- // Merge the terminator operands into the results.
- bool anyChanged = false;
- for (auto it : llvm::zip(terminator->getOperands(), callableResultLattices))
- anyChanged |= std::get<1>(it).meet(latticeValues[std::get<0>(it)]);
- if (!anyChanged)
- return;
-
- // If any of the result lattices changed, update the callers.
- for (Operation *call : latticeIt->second.getSymbolCalls())
- for (auto it : llvm::zip(call->getResults(), callableResultLattices))
- meet(call, latticeValues[std::get<0>(it)], std::get<1>(it));
-}
-
-void SCCPSolver::visitBlock(Block *block) {
- // If the block is not the entry block we need to compute the lattice state
- // for the block arguments. Entry block argument lattices are computed
- // elsewhere, such as when visiting the parent operation.
- if (!block->isEntryBlock()) {
- for (int i : llvm::seq<int>(0, block->getNumArguments()))
- visitBlockArgument(block, i);
- }
-
- // Visit all of the operations within the block.
- for (Operation &op : *block)
- visitOperation(&op);
-}
-
-void SCCPSolver::visitBlockArgument(Block *block, int i) {
- BlockArgument arg = block->getArgument(i);
- LatticeValue &argLattice = latticeValues[arg];
- if (argLattice.isOverdefined())
- return;
-
- bool updatedLattice = false;
- for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
- Block *pred = *it;
-
- // We only care about this predecessor if it is going to execute.
- if (!isEdgeExecutable(pred, block))
- continue;
-
- // Try to get the operand forwarded by the predecessor. If we can't reason
- // about the terminator of the predecessor, mark overdefined.
- Optional<OperandRange> branchOperands;
- if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
- branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
- if (!branchOperands) {
- updatedLattice = true;
- argLattice.markOverdefined();
- break;
- }
-
- // If the operand hasn't been resolved, it is unknown which can merge with
- // anything.
- auto operandLattice = latticeValues.find((*branchOperands)[i]);
- if (operandLattice == latticeValues.end())
- continue;
-
- // Otherwise, meet the two lattice values.
- updatedLattice |= argLattice.meet(operandLattice->second);
- if (argLattice.isOverdefined())
- break;
- }
-
- // If the lattice was updated, visit any executable users of the argument.
- if (updatedLattice)
- visitUsers(arg);
-}
-
-bool SCCPSolver::markEntryBlockExecutable(Region *region,
- bool markArgsOverdefined) {
- if (!region->empty()) {
- if (markArgsOverdefined)
- markAllOverdefined(region->front().getArguments());
- return markBlockExecutable(&region->front());
- }
- return false;
-}
-
-bool SCCPSolver::markBlockExecutable(Block *block) {
- bool marked = executableBlocks.insert(block).second;
- if (marked)
- blockWorklist.push_back(block);
- return marked;
-}
-
-bool SCCPSolver::isBlockExecutable(Block *block) const {
- return executableBlocks.count(block);
-}
-
-void SCCPSolver::markEdgeExecutable(Block *from, Block *to) {
- if (!executableEdges.insert(std::make_pair(from, to)).second)
- return;
- // Mark the destination as executable, and reprocess its arguments if it was
- // already executable.
- if (!markBlockExecutable(to)) {
- for (int i : llvm::seq<int>(0, to->getNumArguments()))
- visitBlockArgument(to, i);
+ // Replace any block arguments with constants.
+ builder.setInsertionPointToStart(block);
+ for (BlockArgument arg : block->getArguments())
+ (void)replaceWithConstant(analysis, builder, folder, arg);
}
}
-bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const {
- return executableEdges.count(std::make_pair(from, to));
-}
-
-void SCCPSolver::markOverdefined(Value value) {
- latticeValues[value].markOverdefined();
-}
-
-bool SCCPSolver::isOverdefined(Value value) const {
- auto it = latticeValues.find(value);
- return it != latticeValues.end() && it->second.isOverdefined();
-}
-
-void SCCPSolver::meet(Operation *owner, LatticeValue &to,
- const LatticeValue &from) {
- if (to.meet(from))
- opWorklist.push_back(owner);
-}
-
//===----------------------------------------------------------------------===//
// SCCP Pass
//===----------------------------------------------------------------------===//
@@ -918,12 +244,9 @@ struct SCCP : public SCCPBase<SCCP> {
void SCCP::runOnOperation() {
Operation *op = getOperation();
- // Solve for SCCP constraints within nested regions.
- SCCPSolver solver(op);
- solver.solve();
-
- // Cleanup any operations using the solver analysis.
- solver.rewrite(&getContext(), op->getRegions());
+ SCCPAnalysis analysis(op->getContext());
+ analysis.run(op);
+ rewrite(analysis, op->getContext(), op->getRegions());
}
std::unique_ptr<Pass> mlir::createSCCPPass() {
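The removed `visitBlockArgument` merges only the operands that arrive along edges already marked executable, and it stops scanning once the argument becomes overdefined. Below is a minimal C++ sketch of that idea under stated assumptions. `ConstLattice`, `IncomingEdge`, and the free function `visitBlockArgument` are hypothetical, simplified stand-ins and not MLIR APIs. The real code uses `LatticeValue`, `BranchOpInterface`, and `Block *` pointers, and it also marks the argument overdefined when a predecessor's terminator cannot be reasoned about. In this sketch, blocks are plain integer IDs and constants are `long` values.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <utility>
#include <vector>

// Hypothetical simplified stand-in for SCCP's LatticeValue: an integer
// constant lattice with three states, unknown -> constant -> overdefined.
class ConstLattice {
public:
  static ConstLattice constant(long v) {
    ConstLattice l;
    l.state = State::Constant;
    l.value = v;
    return l;
  }
  bool isUnknown() const { return state == State::Unknown; }
  bool isOverdefined() const { return state == State::Overdefined; }
  std::optional<long> getConstant() const {
    if (state == State::Constant)
      return value;
    return std::nullopt;
  }
  // Returns true if the state changed.
  bool markOverdefined() {
    if (isOverdefined())
      return false;
    state = State::Overdefined;
    return true;
  }
  // Merges `rhs` into this value (the source calls this `meet`).
  // Returns true if this value changed.
  bool meet(const ConstLattice &rhs) {
    if (isOverdefined() || rhs.isUnknown())
      return false; // Nothing new to learn.
    if (isUnknown()) {
      *this = rhs; // Adopt the first known state.
      return true;
    }
    // This value is a constant: it survives only if rhs is the same constant.
    if (rhs.isOverdefined() || rhs.value != value)
      return markOverdefined();
    return false;
  }

private:
  enum class State { Unknown, Constant, Overdefined };
  State state = State::Unknown;
  long value = 0;
};

// Hypothetical toy edge: the predecessor block ID plus the lattice of the
// operand it forwards to the block argument being visited.
struct IncomingEdge {
  int pred;
  ConstLattice operand;
};

// Sketch of the removed visitBlockArgument: only executable (pred, block)
// edges contribute, and the scan stops once the argument is overdefined.
// Returns true if the argument's lattice changed (callers would then
// revisit its users).
bool visitBlockArgument(ConstLattice &arg, int block,
                        const std::vector<IncomingEdge> &incoming,
                        const std::set<std::pair<int, int>> &executableEdges) {
  if (arg.isOverdefined())
    return false;
  bool changed = false;
  for (const IncomingEdge &edge : incoming) {
    if (!executableEdges.count({edge.pred, block}))
      continue;
    changed |= arg.meet(edge.operand);
    if (arg.isOverdefined())
      break;
  }
  return changed;
}
```

Consider predecessors 1 and 2 that forward the constants 4 and 5 into block 3. While only edge (1, 3) is executable, the argument stays the constant 4. Once edge (2, 3) is marked executable, the argument becomes overdefined. This is why the removed `markEdgeExecutable` re-visits the destination's arguments whenever a new edge appears for a block that was already executable.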