[llvm-branch-commits] [mlir] 086836c - [mlir] An implementation of sparse data-flow analysis

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jun 29 10:23:46 PDT 2022


Author: Mogball
Date: 2022-06-27T09:38:05-07:00
New Revision: 086836c794cc26f1070014003204c63fb3810702

URL: https://github.com/llvm/llvm-project/commit/086836c794cc26f1070014003204c63fb3810702
DIFF: https://github.com/llvm/llvm-project/commit/086836c794cc26f1070014003204c63fb3810702.diff

LOG: [mlir] An implementation of sparse data-flow analysis

This patch introduces a (forward) sparse data-flow analysis implemented with the data-flow analysis framework. The analysis interacts with liveness information that can be provided by dead-code analysis to be conditional. This patch re-implements SCCP using dead-code analysis and (conditional) constant propagation analyses.

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
    mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
    mlir/lib/Transforms/SCCP.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
index ee9392b387e60..5c8c3b8e2fab9 100644
--- a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
+++ b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h
@@ -304,12 +304,17 @@ class PredecessorState : public AnalysisState {
     return knownPredecessors.getArrayRef();
   }
 
-  /// Add a known predecessor.
-  ChangeResult join(Operation *predecessor) {
-    return knownPredecessors.insert(predecessor) ? ChangeResult::Change
-                                                 : ChangeResult::NoChange;
+  /// Get the successor inputs from a predecessor.
+  ValueRange getSuccessorInputs(Operation *predecessor) const {
+    return successorInputs.lookup(predecessor);
   }
 
+  /// Add a known predecessor.
+  ChangeResult join(Operation *predecessor);
+
+  /// Add a known predecessor with successor inputs.
+  ChangeResult join(Operation *predecessor, ValueRange inputs);
+
 private:
   /// Whether all predecessors are known. Optimistically assume that we know
   /// all predecessors.
@@ -319,6 +324,9 @@ class PredecessorState : public AnalysisState {
   SetVector<Operation *, SmallVector<Operation *, 4>,
             SmallPtrSet<Operation *, 4>>
       knownPredecessors;
+
+  /// The successor inputs when branching from a given predecessor.
+  DenseMap<Operation *, ValueRange> successorInputs;
 };
 
 //===----------------------------------------------------------------------===//
@@ -413,6 +421,152 @@ class DeadCodeAnalysis : public DataFlowAnalysis {
   SymbolTableCollection symbolTable;
 };
 
+//===----------------------------------------------------------------------===//
+// AbstractSparseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Base class for sparse (forward) data-flow analyses. A sparse analysis
+/// implements a transfer function on operations from the lattices of the
+/// operands to the lattices of the results. This analysis will propagate
+/// lattices across control-flow edges and the callgraph using liveness
+/// information.
+class AbstractSparseDataFlowAnalysis : public DataFlowAnalysis {
+public:
+  /// Initialize the analysis by visiting every owner of an SSA value: all
+  /// operations and blocks.
+  LogicalResult initialize(Operation *top) override;
+
+  /// Visit a program point. If this is a block and all control-flow
+  /// predecessors or callsites are known, then the argument lattices are
+  /// propagated from them. If this is a call operation or an operation with
+  /// region control-flow, then its result lattices are set accordingly.
+  /// Otherwise, the operation transfer function is invoked.
+  LogicalResult visit(ProgramPoint point) override;
+
+protected:
+  explicit AbstractSparseDataFlowAnalysis(DataFlowSolver &solver);
+
+  /// The operation transfer function. Given the operand lattices, this
+  /// function is expected to set the result lattices.
+  virtual void
+  visitOperationImpl(Operation *op,
+                     ArrayRef<const AbstractSparseLattice *> operandLattices,
+                     ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
+
+  /// Get the lattice element of a value.
+  virtual AbstractSparseLattice *getLatticeElement(Value value) = 0;
+
+  /// Get a read-only lattice element for a value and add it as a dependency to
+  /// a program point.
+  const AbstractSparseLattice *getLatticeElementFor(ProgramPoint point,
+                                                    Value value);
+
+  /// Mark a lattice element as having reached its pessimistic fixpoint and
+  /// propagate an update if changed.
+  void markPessimisticFixpoint(AbstractSparseLattice *lattice);
+
+  /// Mark the given lattice elements as having reached their pessimistic
+  /// fixpoints and propagate an update if any changed.
+  void markAllPessimisticFixpoint(ArrayRef<AbstractSparseLattice *> lattices);
+
+  /// Join `rhs` into `lhs` and propagate an update if `lhs` changed.
+  void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+
+private:
+  /// Recursively initialize the analysis on nested operations and blocks.
+  LogicalResult initializeRecursively(Operation *op);
+
+  /// Visit an operation. If this is a call operation or an operation with
+  /// region control-flow, then its result lattices are set accordingly.
+  /// Otherwise, the operation transfer function is invoked.
+  void visitOperation(Operation *op);
+
+  /// If this is a block and all control-flow predecessors or callsites are
+  /// known, then the argument lattices are propagated from them.
+  void visitBlock(Block *block);
+
+  /// Visit a program point `point` with predecessors within a region branch
+  /// operation `branch`, which can either be the entry block of one of the
+  /// regions or the parent operation itself, and set either the argument or
+  /// parent result lattices.
+  void visitRegionSuccessors(ProgramPoint point, RegionBranchOpInterface branch,
+                             Optional<unsigned> successorIndex,
+                             ArrayRef<AbstractSparseLattice *> lattices);
+};
+
+//===----------------------------------------------------------------------===//
+// SparseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+/// A sparse (forward) data-flow analysis for propagating SSA value lattices
+/// across the IR by implementing transfer functions for operations.
+///
+/// `StateT` is expected to be a subclass of `AbstractSparseLattice`.
+template <typename StateT>
+class SparseDataFlowAnalysis : public AbstractSparseDataFlowAnalysis {
+public:
+  explicit SparseDataFlowAnalysis(DataFlowSolver &solver)
+      : AbstractSparseDataFlowAnalysis(solver) {}
+
+  /// Visit an operation with the lattices of its operands. This function is
+  /// expected to set the lattices of the operation's results.
+  virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
+                              ArrayRef<StateT *> results) = 0;
+
+protected:
+  /// Get the lattice element for a value.
+  StateT *getLatticeElement(Value value) override {
+    return getOrCreate<StateT>(value);
+  }
+
+  /// Get the lattice element for a value and create a dependency on the
+  /// provided program point.
+  const StateT *getLatticeElementFor(ProgramPoint point, Value value) {
+    return static_cast<const StateT *>(
+        AbstractSparseDataFlowAnalysis::getLatticeElementFor(point, value));
+  }
+
+  /// Mark the lattice elements of a range of values as having reached their
+  /// pessimistic fixpoint.
+  void markAllPessimisticFixpoint(ArrayRef<StateT *> lattices) {
+    AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
+        {reinterpret_cast<AbstractSparseLattice *const *>(lattices.begin()),
+         lattices.size()});
+  }
+
+private:
+  /// Type-erased wrappers that convert the abstract lattice operands to derived
+  /// lattices and invoke the virtual hooks operating on the derived lattices.
+  void visitOperationImpl(
+      Operation *op, ArrayRef<const AbstractSparseLattice *> operandLattices,
+      ArrayRef<AbstractSparseLattice *> resultLattices) override {
+    visitOperation(
+        op,
+        {reinterpret_cast<const StateT *const *>(operandLattices.begin()),
+         operandLattices.size()},
+        {reinterpret_cast<StateT *const *>(resultLattices.begin()),
+         resultLattices.size()});
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// SparseConstantPropagation
+//===----------------------------------------------------------------------===//
+
+/// This analysis implements sparse constant propagation, which attempts to
+/// determine constant-valued results for operations using constant-valued
+/// operands, by speculatively folding operations. When combined with dead-code
+/// analysis, this becomes sparse conditional constant propagation (SCCP).
+class SparseConstantPropagation
+    : public SparseDataFlowAnalysis<Lattice<ConstantValue>> {
+public:
+  using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
+
+  void visitOperation(Operation *op,
+                      ArrayRef<const Lattice<ConstantValue> *> operands,
+                      ArrayRef<Lattice<ConstantValue> *> results) override;
+};
+
 } // end namespace mlir
 
 #endif // MLIR_ANALYSIS_SPARSEDATAFLOWANALYSIS_H

diff  --git a/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
index 41ff8381b0749..80c1293d22729 100644
--- a/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/SparseDataFlowAnalysis.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/SparseDataFlowAnalysis.h"
+#include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "dataflow"
 
@@ -77,6 +78,23 @@ void PredecessorState::print(raw_ostream &os) const {
     os << "  " << *op << "\n";
 }
 
+ChangeResult PredecessorState::join(Operation *predecessor) {
+  return knownPredecessors.insert(predecessor) ? ChangeResult::Change
+                                               : ChangeResult::NoChange;
+}
+
+ChangeResult PredecessorState::join(Operation *predecessor, ValueRange inputs) {
+  ChangeResult result = join(predecessor);
+  if (!inputs.empty()) {
+    ValueRange &curInputs = successorInputs[predecessor];
+    if (curInputs != inputs) {
+      curInputs = inputs;
+      result |= ChangeResult::Change;
+    }
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // CFGEdge
 //===----------------------------------------------------------------------===//
@@ -352,14 +370,18 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
   branch.getSuccessorRegions(/*index=*/{}, *operands, successors);
 
   for (const RegionSuccessor &successor : successors) {
+    // The successor can be either an entry block or the parent operation.
+    ProgramPoint point = successor.getSuccessor()
+                             ? &successor.getSuccessor()->front()
+                             : ProgramPoint(branch);
     // Mark the entry block as executable.
-    Region *region = successor.getSuccessor();
-    assert(region && "expected a region successor");
-    auto *state = getOrCreate<Executable>(&region->front());
+    auto *state = getOrCreate<Executable>(point);
     propagateIfChanged(state, state->setToLive());
     // Add the parent op as a predecessor.
-    auto *predecessors = getOrCreate<PredecessorState>(&region->front());
-    propagateIfChanged(predecessors, predecessors->join(branch));
+    auto *predecessors = getOrCreate<PredecessorState>(point);
+    propagateIfChanged(
+        predecessors,
+        predecessors->join(branch, successor.getSuccessorInputs()));
   }
 }
 
@@ -385,7 +407,8 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
       // Add this terminator as a predecessor to the parent op.
       predecessors = getOrCreate<PredecessorState>(branch);
     }
-    propagateIfChanged(predecessors, predecessors->join(op));
+    propagateIfChanged(predecessors,
+                       predecessors->join(op, successor.getSuccessorInputs()));
   }
 }
 
@@ -411,3 +434,336 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op,
     }
   }
 }
+
+//===----------------------------------------------------------------------===//
+// AbstractSparseDataFlowAnalysis
+//===----------------------------------------------------------------------===//
+
+AbstractSparseDataFlowAnalysis::AbstractSparseDataFlowAnalysis(
+    DataFlowSolver &solver)
+    : DataFlowAnalysis(solver) {
+  registerPointKind<CFGEdge>();
+}
+
+LogicalResult AbstractSparseDataFlowAnalysis::initialize(Operation *top) {
+  // Mark the entry block arguments as having reached their pessimistic
+  // fixpoints.
+  for (Region &region : top->getRegions()) {
+    if (region.empty())
+      continue;
+    for (Value argument : region.front().getArguments())
+      markPessimisticFixpoint(getLatticeElement(argument));
+  }
+
+  return initializeRecursively(top);
+}
+
+LogicalResult
+AbstractSparseDataFlowAnalysis::initializeRecursively(Operation *op) {
+  // Initialize the analysis by visiting every owner of an SSA value (all
+  // operations and blocks).
+  visitOperation(op);
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region) {
+      getOrCreate<Executable>(&block)->blockContentSubscribe(this);
+      visitBlock(&block);
+      for (Operation &op : block)
+        if (failed(initializeRecursively(&op)))
+          return failure();
+    }
+  }
+
+  return success();
+}
+
+LogicalResult AbstractSparseDataFlowAnalysis::visit(ProgramPoint point) {
+  if (Operation *op = point.dyn_cast<Operation *>())
+    visitOperation(op);
+  else if (Block *block = point.dyn_cast<Block *>())
+    visitBlock(block);
+  else
+    return failure();
+  return success();
+}
+
+void AbstractSparseDataFlowAnalysis::visitOperation(Operation *op) {
+  // Exit early on operations with no results.
+  if (op->getNumResults() == 0)
+    return;
+
+  // If the containing block is not executable, bail out.
+  if (!getOrCreate<Executable>(op->getBlock())->isLive())
+    return;
+
+  // Get the result lattices.
+  SmallVector<AbstractSparseLattice *> resultLattices;
+  resultLattices.reserve(op->getNumResults());
+  // Track whether all results have reached their fixpoint.
+  bool allAtFixpoint = true;
+  for (Value result : op->getResults()) {
+    AbstractSparseLattice *resultLattice = getLatticeElement(result);
+    allAtFixpoint &= resultLattice->isAtFixpoint();
+    resultLattices.push_back(resultLattice);
+  }
+  // If all result lattices have reached a fixpoint, there is nothing to do.
+  if (allAtFixpoint)
+    return;
+
+  // The results of a region branch operation are determined by control-flow.
+  if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+    return visitRegionSuccessors({branch}, branch,
+                                 /*successorIndex=*/llvm::None, resultLattices);
+  }
+
+  // The results of a call operation are determined by the callgraph.
+  if (auto call = dyn_cast<CallOpInterface>(op)) {
+    const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
+    // If not all return sites are known, then conservatively assume we can't
+    // reason about the data-flow.
+    if (!predecessors->allPredecessorsKnown())
+      return markAllPessimisticFixpoint(resultLattices);
+    for (Operation *predecessor : predecessors->getKnownPredecessors())
+      for (auto it : llvm::zip(predecessor->getOperands(), resultLattices))
+        join(std::get<1>(it), *getLatticeElementFor(op, std::get<0>(it)));
+    return;
+  }
+
+  // Grab the lattice elements of the operands.
+  SmallVector<const AbstractSparseLattice *> operandLattices;
+  operandLattices.reserve(op->getNumOperands());
+  for (Value operand : op->getOperands()) {
+    AbstractSparseLattice *operandLattice = getLatticeElement(operand);
+    operandLattice->useDefSubscribe(this);
+    // If any of the operand states are not initialized, bail out.
+    if (operandLattice->isUninitialized())
+      return;
+    operandLattices.push_back(operandLattice);
+  }
+
+  // Invoke the operation transfer function.
+  visitOperationImpl(op, operandLattices, resultLattices);
+}
+
+void AbstractSparseDataFlowAnalysis::visitBlock(Block *block) {
+  // Exit early on blocks with no arguments.
+  if (block->getNumArguments() == 0)
+    return;
+
+  // If the block is not executable, bail out.
+  if (!getOrCreate<Executable>(block)->isLive())
+    return;
+
+  // Get the argument lattices.
+  SmallVector<AbstractSparseLattice *> argLattices;
+  argLattices.reserve(block->getNumArguments());
+  bool allAtFixpoint = true;
+  for (BlockArgument argument : block->getArguments()) {
+    AbstractSparseLattice *argLattice = getLatticeElement(argument);
+    allAtFixpoint &= argLattice->isAtFixpoint();
+    argLattices.push_back(argLattice);
+  }
+  // If all argument lattices have reached their fixpoints, then there is
+  // nothing to do.
+  if (allAtFixpoint)
+    return;
+
+  // The argument lattices of entry blocks are set by region control-flow or the
+  // callgraph.
+  if (block->isEntryBlock()) {
+    // Check if this block is the entry block of a callable region.
+    auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
+    if (callable && callable.getCallableRegion() == block->getParent()) {
+      const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
+      // If not all callsites are known, conservatively mark all lattices as
+      // having reached their pessimistic fixpoints.
+      if (!callsites->allPredecessorsKnown())
+        return markAllPessimisticFixpoint(argLattices);
+      for (Operation *callsite : callsites->getKnownPredecessors()) {
+        auto call = cast<CallOpInterface>(callsite);
+        for (auto it : llvm::zip(call.getArgOperands(), argLattices))
+          join(std::get<1>(it), *getLatticeElementFor(block, std::get<0>(it)));
+      }
+      return;
+    }
+
+    // Check if the lattices can be determined from region control flow.
+    if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
+      return visitRegionSuccessors(
+          block, branch, block->getParent()->getRegionNumber(), argLattices);
+    }
+
+    // Otherwise, we can't reason about the data-flow.
+    return markAllPessimisticFixpoint(argLattices);
+  }
+
+  // Iterate over the predecessors of the non-entry block.
+  for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
+       it != e; ++it) {
+    Block *predecessor = *it;
+
+    // If the edge from the predecessor block to the current block is not live,
+    // bail out.
+    auto *edgeExecutable =
+        getOrCreate<Executable>(getProgramPoint<CFGEdge>(predecessor, block));
+    edgeExecutable->blockContentSubscribe(this);
+    if (!edgeExecutable->isLive())
+      continue;
+
+    // Check if we can reason about the data-flow from the predecessor.
+    if (auto branch =
+            dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
+      SuccessorOperands operands =
+          branch.getSuccessorOperands(it.getSuccessorIndex());
+      for (auto &it : llvm::enumerate(argLattices)) {
+        if (Value operand = operands[it.index()]) {
+          join(it.value(), *getLatticeElementFor(block, operand));
+        } else {
+          // Conservatively mark internally produced arguments as having reached
+          // their pessimistic fixpoint.
+          markPessimisticFixpoint(it.value());
+        }
+      }
+    } else {
+      return markAllPessimisticFixpoint(argLattices);
+    }
+  }
+}
+
+void AbstractSparseDataFlowAnalysis::visitRegionSuccessors(
+    ProgramPoint point, RegionBranchOpInterface branch,
+    Optional<unsigned> successorIndex,
+    ArrayRef<AbstractSparseLattice *> lattices) {
+  const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
+  assert(predecessors->allPredecessorsKnown() &&
+         "unexpected unresolved region successors");
+
+  for (Operation *op : predecessors->getKnownPredecessors()) {
+    // Get the incoming successor operands.
+    Optional<OperandRange> operands;
+
+    // Check if the predecessor is the parent op.
+    if (op == branch) {
+      operands = branch.getSuccessorEntryOperands(successorIndex);
+      // Otherwise, try to deduce the operands from a region return-like op.
+    } else {
+      assert(op->hasTrait<OpTrait::IsTerminator>() && "expected a terminator");
+      if (isRegionReturnLike(op))
+        operands = getRegionBranchSuccessorOperands(op, successorIndex);
+    }
+
+    // Without known operands, we can't reason about the data-flow.
+    if (!operands)
+      return markAllPessimisticFixpoint(lattices);
+
+    ValueRange inputs = predecessors->getSuccessorInputs(op);
+    assert(inputs.size() == operands->size() &&
+           "expected the same number of successor inputs as operands");
+
+    // TODO: This was updated to be exposed upstream.
+    unsigned firstIndex = 0;
+    if (inputs.size() != lattices.size()) {
+      if (inputs.empty())
+        return markAllPessimisticFixpoint(lattices);
+      // Inputs are op results (parent op) or block arguments (entry block).
+      firstIndex = inputs.front().isa<OpResult>()
+                       ? inputs.front().cast<OpResult>().getResultNumber()
+                       : inputs.front().cast<BlockArgument>().getArgNumber();
+      markAllPessimisticFixpoint(lattices.take_front(firstIndex));
+      markAllPessimisticFixpoint(
+          lattices.drop_front(firstIndex + inputs.size()));
+    }
+
+    for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex)))
+      join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
+  }
+}
+
+const AbstractSparseLattice *
+AbstractSparseDataFlowAnalysis::getLatticeElementFor(ProgramPoint point,
+                                                     Value value) {
+  AbstractSparseLattice *state = getLatticeElement(value);
+  addDependency(state, point);
+  return state;
+}
+
+void AbstractSparseDataFlowAnalysis::markPessimisticFixpoint(
+    AbstractSparseLattice *lattice) {
+  propagateIfChanged(lattice, lattice->markPessimisticFixpoint());
+}
+
+void AbstractSparseDataFlowAnalysis::markAllPessimisticFixpoint(
+    ArrayRef<AbstractSparseLattice *> lattices) {
+  for (AbstractSparseLattice *lattice : lattices) {
+    markPessimisticFixpoint(lattice);
+  }
+}
+
+void AbstractSparseDataFlowAnalysis::join(AbstractSparseLattice *lhs,
+                                          const AbstractSparseLattice &rhs) {
+  propagateIfChanged(lhs, lhs->join(rhs));
+}
+
+//===----------------------------------------------------------------------===//
+// SparseConstantPropagation
+//===----------------------------------------------------------------------===//
+
+void SparseConstantPropagation::visitOperation(
+    Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
+    ArrayRef<Lattice<ConstantValue> *> results) {
+  LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
+
+  // Don't try to simulate the results of a region operation as we can't
+  // guarantee that folding will be out-of-place. We don't allow in-place
+  // folds as the desire here is for simulated execution, and not general
+  // folding. Conservatively mark the results as overdefined instead.
+  if (op->getNumRegions())
+    return markAllPessimisticFixpoint(results);
+
+  SmallVector<Attribute, 8> constantOperands;
+  constantOperands.reserve(op->getNumOperands());
+  for (auto *operandLattice : operands)
+    constantOperands.push_back(operandLattice->getValue().getConstantValue());
+
+  // Save the original operands and attributes just in case the operation
+  // folds in-place. The constant passed in may not correspond to the real
+  // runtime value, so in-place updates are not allowed.
+  SmallVector<Value, 8> originalOperands(op->getOperands());
+  DictionaryAttr originalAttrs = op->getAttrDictionary();
+
+  // Simulate the result of folding this operation to a constant. If folding
+  // fails or was an in-place fold, mark the results as overdefined.
+  SmallVector<OpFoldResult, 8> foldResults;
+  foldResults.reserve(op->getNumResults());
+  if (failed(op->fold(constantOperands, foldResults))) {
+    markAllPessimisticFixpoint(results);
+    return;
+  }
+
+  // If the folding was in-place, mark the results as overdefined and reset
+  // the operation. We don't allow in-place folds as the desire here is for
+  // simulated execution, and not general folding.
+  if (foldResults.empty()) {
+    op->setOperands(originalOperands);
+    op->setAttrs(originalAttrs);
+    return markAllPessimisticFixpoint(results);
+  }
+
+  // Merge the fold results into the lattice for this operation.
+  assert(foldResults.size() == op->getNumResults() && "invalid result size");
+  for (const auto it : llvm::zip(results, foldResults)) {
+    Lattice<ConstantValue> *lattice = std::get<0>(it);
+
+    // Merge in the result of the fold, either a constant or a value.
+    OpFoldResult foldResult = std::get<1>(it);
+    if (Attribute attr = foldResult.dyn_cast<Attribute>()) {
+      LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
+      propagateIfChanged(lattice,
+                         lattice->join(ConstantValue(attr, op->getDialect())));
+    } else {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Folded to value: " << foldResult.get<Value>() << "\n");
+      AbstractSparseDataFlowAnalysis::join(
+          lattice, *getLatticeElement(foldResult.get<Value>()));
+    }
+  }
+}

diff  --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
index 11d55e7454a0a..548950aaed81a 100644
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -15,151 +15,16 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
-#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Analysis/SparseDataFlowAnalysis.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/Passes.h"
-#include "llvm/Support/Debug.h"
-
-#define DEBUG_TYPE "sccp"
 
 using namespace mlir;
 
-//===----------------------------------------------------------------------===//
-// SCCP Analysis
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct SCCPLatticeValue {
-  SCCPLatticeValue(Attribute constant = {}, Dialect *dialect = nullptr)
-      : constant(constant), constantDialect(dialect) {}
-
-  /// The pessimistic state of SCCP is non-constant.
-  static SCCPLatticeValue getPessimisticValueState(MLIRContext *context) {
-    return SCCPLatticeValue();
-  }
-  static SCCPLatticeValue getPessimisticValueState(Value value) {
-    return SCCPLatticeValue();
-  }
-
-  /// Equivalence for SCCP only accounts for the constant, not the originating
-  /// dialect.
-  bool operator==(const SCCPLatticeValue &rhs) const {
-    return constant == rhs.constant;
-  }
-
-  /// To join the state of two values, we simply check for equivalence.
-  static SCCPLatticeValue join(const SCCPLatticeValue &lhs,
-                               const SCCPLatticeValue &rhs) {
-    return lhs == rhs ? lhs : SCCPLatticeValue();
-  }
-
-  /// The constant attribute value.
-  Attribute constant;
-
-  /// The dialect the constant originated from. This is not used as part of the
-  /// key, and is only needed to materialize the held constant if necessary.
-  Dialect *constantDialect;
-};
-
-struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
-  using ForwardDataFlowAnalysis<SCCPLatticeValue>::ForwardDataFlowAnalysis;
-  ~SCCPAnalysis() override = default;
-
-  ChangeResult
-  visitOperation(Operation *op,
-                 ArrayRef<LatticeElement<SCCPLatticeValue> *> operands) final {
-
-    LLVM_DEBUG(llvm::dbgs() << "SCCP: Visiting operation: " << *op << "\n");
-
-    // Don't try to simulate the results of a region operation as we can't
-    // guarantee that folding will be out-of-place. We don't allow in-place
-    // folds as the desire here is for simulated execution, and not general
-    // folding.
-    if (op->getNumRegions())
-      return markAllPessimisticFixpoint(op->getResults());
-
-    SmallVector<Attribute> constantOperands(
-        llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
-          return value->getValue().constant;
-        }));
-
-    // Save the original operands and attributes just in case the operation
-    // folds in-place. The constant passed in may not correspond to the real
-    // runtime value, so in-place updates are not allowed.
-    SmallVector<Value, 8> originalOperands(op->getOperands());
-    DictionaryAttr originalAttrs = op->getAttrDictionary();
-
-    // Simulate the result of folding this operation to a constant. If folding
-    // fails or was an in-place fold, mark the results as overdefined.
-    SmallVector<OpFoldResult, 8> foldResults;
-    foldResults.reserve(op->getNumResults());
-    if (failed(op->fold(constantOperands, foldResults)))
-      return markAllPessimisticFixpoint(op->getResults());
-
-    // If the folding was in-place, mark the results as overdefined and reset
-    // the operation. We don't allow in-place folds as the desire here is for
-    // simulated execution, and not general folding.
-    if (foldResults.empty()) {
-      op->setOperands(originalOperands);
-      op->setAttrs(originalAttrs);
-      return markAllPessimisticFixpoint(op->getResults());
-    }
-
-    // Merge the fold results into the lattice for this operation.
-    assert(foldResults.size() == op->getNumResults() && "invalid result size");
-    Dialect *dialect = op->getDialect();
-    ChangeResult result = ChangeResult::NoChange;
-    for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
-      LatticeElement<SCCPLatticeValue> &lattice =
-          getLatticeElement(op->getResult(i));
-
-      // Merge in the result of the fold, either a constant or a value.
-      OpFoldResult foldResult = foldResults[i];
-      if (Attribute attr = foldResult.dyn_cast<Attribute>())
-        result |= lattice.join(SCCPLatticeValue(attr, dialect));
-      else
-        result |= lattice.join(getLatticeElement(foldResult.get<Value>()));
-    }
-    return result;
-  }
-
-  /// Implementation of `getSuccessorsForOperands` that uses constant operands
-  /// to potentially remove dead successors.
-  LogicalResult getSuccessorsForOperands(
-      BranchOpInterface branch,
-      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
-      SmallVectorImpl<Block *> &successors) final {
-    SmallVector<Attribute> constantOperands(
-        llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
-          return value->getValue().constant;
-        }));
-    if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
-      successors.push_back(singleSucc);
-      return success();
-    }
-    return failure();
-  }
-
-  /// Implementation of `getSuccessorsForOperands` that uses constant operands
-  /// to potentially remove dead region successors.
-  void getSuccessorsForOperands(
-      RegionBranchOpInterface branch, Optional<unsigned> sourceIndex,
-      ArrayRef<LatticeElement<SCCPLatticeValue> *> operands,
-      SmallVectorImpl<RegionSuccessor> &successors) final {
-    SmallVector<Attribute> constantOperands(
-        llvm::map_range(operands, [](LatticeElement<SCCPLatticeValue> *value) {
-          return value->getValue().constant;
-        }));
-    branch.getSuccessorRegions(sourceIndex, constantOperands, successors);
-  }
-};
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // SCCP Rewrites
 //===----------------------------------------------------------------------===//
@@ -167,21 +32,21 @@ struct SCCPAnalysis : public ForwardDataFlowAnalysis<SCCPLatticeValue> {
 /// Replace the given value with a constant if the corresponding lattice
 /// represents a constant. Returns success if the value was replaced, failure
 /// otherwise.
-static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
+static LogicalResult replaceWithConstant(DataFlowSolver &solver,
                                          OpBuilder &builder,
                                          OperationFolder &folder, Value value) {
-  LatticeElement<SCCPLatticeValue> *lattice =
-      analysis.lookupLatticeElement(value);
+  auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
   if (!lattice)
     return failure();
-  SCCPLatticeValue &latticeValue = lattice->getValue();
-  if (!latticeValue.constant)
+  const ConstantValue &latticeValue = lattice->getValue();
+  if (!latticeValue.getConstantValue())
     return failure();
 
   // Attempt to materialize a constant for the given value.
-  Dialect *dialect = latticeValue.constantDialect;
-  Value constant = folder.getOrCreateConstant(
-      builder, dialect, latticeValue.constant, value.getType(), value.getLoc());
+  Dialect *dialect = latticeValue.getConstantDialect();
+  Value constant = folder.getOrCreateConstant(builder, dialect,
+                                              latticeValue.getConstantValue(),
+                                              value.getType(), value.getLoc());
   if (!constant)
     return failure();
 
@@ -192,7 +57,7 @@ static LogicalResult replaceWithConstant(SCCPAnalysis &analysis,
 /// Rewrite the given regions using the computing analysis. This replaces the
 /// uses of all values that have been computed to be constant, and erases as
 /// many newly dead operations.
-static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
+static void rewrite(DataFlowSolver &solver, MLIRContext *context,
                     MutableArrayRef<Region> initialRegions) {
   SmallVector<Block *> worklist;
   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
@@ -216,7 +81,7 @@ static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
       bool replacedAll = op.getNumResults() != 0;
       for (Value res : op.getResults())
         replacedAll &=
-            succeeded(replaceWithConstant(analysis, builder, folder, res));
+            succeeded(replaceWithConstant(solver, builder, folder, res));
 
       // If all of the results of the operation were replaced, try to erase
       // the operation completely.
@@ -233,7 +98,7 @@ static void rewrite(SCCPAnalysis &analysis, MLIRContext *context,
     // Replace any block arguments with constants.
     builder.setInsertionPointToStart(block);
     for (BlockArgument arg : block->getArguments())
-      (void)replaceWithConstant(analysis, builder, folder, arg);
+      (void)replaceWithConstant(solver, builder, folder, arg);
   }
 }
 
@@ -250,9 +115,14 @@ struct SCCP : public SCCPBase<SCCP> {
 void SCCP::runOnOperation() {
   Operation *op = getOperation();
 
-  SCCPAnalysis analysis(op->getContext());
-  analysis.run(op);
-  rewrite(analysis, op->getContext(), op->getRegions());
+  DataFlowSolver solver;
+  solver.load<DeadCodeAnalysis>();
+  solver.load<SparseConstantPropagation>();
+  if (failed(solver.initializeAndRun(op))) {
+    op->emitError("SCCP analysis failed\n");
+    return signalPassFailure();
+  }
+  rewrite(solver, op->getContext(), op->getRegions());
 }
 
 std::unique_ptr<Pass> mlir::createSCCPPass() {


        


More information about the llvm-branch-commits mailing list