[Mlir-commits] [mlir] [mlir] support non-interprocedural dataflow analyses (PR #75583)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 15 02:34:30 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Oleksandr "Alex" Zinenko (ftynse)

<details>
<summary>Changes</summary>

The core implementation of the dataflow analysis framework is interprocedural by design. While this offers better analysis precision, it also comes at an additional cost, because the analysis takes longer to reach its fixpoint state. Add a configuration mechanism to the dataflow solver that controls whether it operates interprocedurally, giving clients a choice.

As a positive side effect, this change also adds hooks for explicitly processing external/opaque function calls in the dataflow analyses, e.g., based on attributes present in the function declaration or call operation, such as the alias scopes and modref information available in the LLVM dialect.

This change should not affect existing analyses and the default solver configuration remains interprocedural.

---

Patch is 67.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75583.diff


12 Files Affected:

- (modified) mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h (+30-10) 
- (modified) mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h (+55) 
- (modified) mlir/include/mlir/Analysis/DataFlowFramework.h (+38) 
- (modified) mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp (+29-13) 
- (modified) mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp (+38-19) 
- (modified) mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir (+103-22) 
- (modified) mlir/test/Analysis/DataFlow/test-next-access.mlir (+108-32) 
- (modified) mlir/test/Analysis/DataFlow/test-written-to.mlir (+75-15) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp (+85-18) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h (+76-20) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp (+88-19) 
- (modified) mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp (+46-3) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 6a1335bab8bf6e..088b6cd7d698fc 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -27,8 +27,9 @@ namespace dataflow {
 // CallControlFlowAction
 //===----------------------------------------------------------------------===//
 
-/// Indicates whether the control enters or exits the callee.
-enum class CallControlFlowAction { EnterCallee, ExitCallee };
+/// Indicates whether the control enters, exits, or skips over the callee (in
+/// the case of external functions).
+enum class CallControlFlowAction { EnterCallee, ExitCallee, ExternalCallee };
 
 //===----------------------------------------------------------------------===//
 // AbstractDenseLattice
@@ -131,14 +132,21 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
 
   /// Propagate the dense lattice forward along the call control flow edge,
   /// which can be either entering or exiting the callee. Default implementation
-  /// just meets the states, meaning that operations implementing
-  /// `CallOpInterface` don't have any effect on the lattice that isn't already
-  /// expressed by the interface itself.
+  /// for enter and exit callee actions just meets the states, meaning that
+  /// operations implementing `CallOpInterface` don't have any effect on the
+  /// lattice that isn't already expressed by the interface itself. Default
+  /// implementation for the external callee action additionally sets the
+  /// "after" lattice to the entry state.
   virtual void visitCallControlFlowTransfer(CallOpInterface call,
                                             CallControlFlowAction action,
                                             const AbstractDenseLattice &before,
                                             AbstractDenseLattice *after) {
     join(after, before);
+    // Note that `setToEntryState` may be a "partial fixpoint" for some
+    // lattices, e.g., lattices that are lists of maps of other lattices will
+    // only set fixpoint for "known" lattices.
+    if (action == CallControlFlowAction::ExternalCallee)
+      setToEntryState(after);
   }
 
   /// Visit a program point within a region branch operation with predecessors
@@ -155,7 +163,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
 
   /// Visit an operation for which the data flow is described by the
   /// `CallOpInterface`.
-  void visitCallOperation(CallOpInterface call, AbstractDenseLattice *after);
+  void visitCallOperation(CallOpInterface call,
+                          const AbstractDenseLattice &before,
+                          AbstractDenseLattice *after);
 };
 
 //===----------------------------------------------------------------------===//
@@ -361,14 +371,22 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
 
   /// Propagate the dense lattice backwards along the call control flow edge,
   /// which can be either entering or exiting the callee. Default implementation
-  /// just meets the states, meaning that operations implementing
-  /// `CallOpInterface` don't have any effect on hte lattice that isn't already
-  /// expressed by the interface itself.
+  /// for enter and exit callee action just meets the states, meaning that
+  /// operations implementing `CallOpInterface` don't have any effect on the
+  /// lattice that isn't already expressed by the interface itself. Default
+  /// implementation for external callee action additional sets the result to
+  /// the exit (fixpoint) state.
   virtual void visitCallControlFlowTransfer(CallOpInterface call,
                                             CallControlFlowAction action,
                                             const AbstractDenseLattice &after,
                                             AbstractDenseLattice *before) {
     meet(before, after);
+
+    // Note that `setToExitState` may be a "partial fixpoint" for some lattices,
+    // e.g., lattices that are lists of maps of other lattices will only
+    // set fixpoint for "known" lattices.
+    if (action == CallControlFlowAction::ExternalCallee)
+      setToExitState(before);
   }
 
 private:
@@ -394,7 +412,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   ///     otherwise,
   ///   - meet that state with the state before the call-like op, or use the
   ///     custom logic if overridden by concrete analyses.
-  void visitCallOperation(CallOpInterface call, AbstractDenseLattice *before);
+  void visitCallOperation(CallOpInterface call,
+                          const AbstractDenseLattice &after,
+                          AbstractDenseLattice *before);
 
   /// Symbol table for call-level control flow.
   SymbolTableCollection &symbolTable;
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 5a9a36159b56c5..b65ac8bb1dec27 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -17,6 +17,7 @@
 
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "llvm/ADT/SmallPtrSet.h"
 
@@ -199,6 +200,12 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
                      ArrayRef<const AbstractSparseLattice *> operandLattices,
                      ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
 
+  /// The transfer function for calls to external functions.
+  virtual void visitExternalCallImpl(
+      CallOpInterface call,
+      ArrayRef<const AbstractSparseLattice *> argumentLattices,
+      ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
+
   /// Given an operation with region control-flow, the lattices of the operands,
   /// and a region successor, compute the lattice values for block arguments
   /// that are not accounted for by the branching control flow (ex. the bounds
@@ -271,6 +278,14 @@ class SparseForwardDataFlowAnalysis
   virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
                               ArrayRef<StateT *> results) = 0;
 
+  /// Visit a call operation to an externally defined function given the
+  /// lattices of its arguments.
+  virtual void visitExternalCall(CallOpInterface call,
+                                 ArrayRef<const StateT *> argumentLattices,
+                                 ArrayRef<StateT *> resultLattices) {
+    setAllToEntryStates(resultLattices);
+  }
+
   /// Given an operation with possible region control-flow, the lattices of the
   /// operands, and a region successor, compute the lattice values for block
   /// arguments that are not accounted for by the branching control flow (ex.
@@ -321,6 +336,17 @@ class SparseForwardDataFlowAnalysis
         {reinterpret_cast<StateT *const *>(resultLattices.begin()),
          resultLattices.size()});
   }
+  void visitExternalCallImpl(
+      CallOpInterface call,
+      ArrayRef<const AbstractSparseLattice *> argumentLattices,
+      ArrayRef<AbstractSparseLattice *> resultLattices) override {
+    visitExternalCall(
+        call,
+        {reinterpret_cast<const StateT *const *>(argumentLattices.begin()),
+         argumentLattices.size()},
+        {reinterpret_cast<StateT *const *>(resultLattices.begin()),
+         resultLattices.size()});
+  }
   void visitNonControlFlowArgumentsImpl(
       Operation *op, const RegionSuccessor &successor,
       ArrayRef<AbstractSparseLattice *> argLattices,
@@ -363,6 +389,11 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
       Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
       ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
 
+  /// The transfer function for calls to external functions.
+  virtual void visitExternalCallImpl(
+      CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
+      ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
+
   // Visit operands on branch instructions that are not forwarded.
   virtual void visitBranchOperand(OpOperand &operand) = 0;
 
@@ -444,6 +475,19 @@ class SparseBackwardDataFlowAnalysis
   virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
                               ArrayRef<const StateT *> results) = 0;
 
+  /// Visit a call to an external function. This function is expected to set
+  /// lattice values of the call operands. By default, calls `visitCallOperand`
+  /// for all operands.
+  virtual void visitExternalCall(CallOpInterface call,
+                                 ArrayRef<StateT *> argumentLattices,
+                                 ArrayRef<const StateT *> resultLattices) {
+    (void)argumentLattices;
+    (void)resultLattices;
+    for (OpOperand &operand : call->getOpOperands()) {
+      visitCallOperand(operand);
+    }
+  };
+
 protected:
   /// Get the lattice element for a value.
   StateT *getLatticeElement(Value value) override {
@@ -474,6 +518,17 @@ class SparseBackwardDataFlowAnalysis
         {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
          resultLattices.size()});
   }
+
+  void visitExternalCallImpl(
+      CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
+      ArrayRef<const AbstractSparseLattice *> resultLattices) override {
+    visitExternalCall(
+        call,
+        {reinterpret_cast<StateT *const *>(operandLattices.begin()),
+         operandLattices.size()},
+        {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
+         resultLattices.size()});
+  }
 };
 
 } // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index c27615b52a12b8..541cdb1e237c1b 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -175,6 +175,32 @@ struct ProgramPoint
 /// Forward declaration of the data-flow analysis class.
 class DataFlowAnalysis;
 
+//===----------------------------------------------------------------------===//
+// DataFlowConfig
+//===----------------------------------------------------------------------===//
+
+/// Configuration class for data flow solver and child analyses. Follows the
+/// fluent API pattern.
+class DataFlowConfig {
+public:
+  DataFlowConfig() = default;
+
+  /// Set whether the solver should operate interpocedurally, i.e. enter the
+  /// callee body when available. Interprocedural analyses may be more precise,
+  /// but also more expensive as more states need to be computed and the
+  /// fixpoint convergence takes longer.
+  DataFlowConfig &setInterprocedural(bool enable) {
+    interprocedural = enable;
+    return *this;
+  }
+
+  /// Return `true` if the solver operates interprocedurally, `false` otherwise.
+  bool isInterprocedural() const { return interprocedural; }
+
+private:
+  bool interprocedural = true;
+};
+
 //===----------------------------------------------------------------------===//
 // DataFlowSolver
 //===----------------------------------------------------------------------===//
@@ -195,6 +221,9 @@ class DataFlowAnalysis;
 /// TODO: Optimize the internal implementation of the solver.
 class DataFlowSolver {
 public:
+  explicit DataFlowSolver(const DataFlowConfig &config = DataFlowConfig())
+      : config(config) {}
+
   /// Load an analysis into the solver. Return the analysis instance.
   template <typename AnalysisT, typename... Args>
   AnalysisT *load(Args &&...args);
@@ -236,7 +265,13 @@ class DataFlowSolver {
   /// dependent work items to the back of the queue.
   void propagateIfChanged(AnalysisState *state, ChangeResult changed);
 
+  /// Get the configuration of the solver.
+  const DataFlowConfig &getConfig() const { return config; }
+
 private:
+  /// Configuration of the dataflow solver.
+  DataFlowConfig config;
+
   /// The solver's work queue. Work items can be inserted to the front of the
   /// queue to be processed greedily, speeding up computations that otherwise
   /// quickly degenerate to quadratic due to propagation of state updates.
@@ -423,6 +458,9 @@ class DataFlowAnalysis {
     return state;
   }
 
+  /// Return the configuration of the solver used for this analysis.
+  const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+
 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
   /// When compiling with debugging, keep a name for the analyis.
   StringRef debugName;
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index a6c9f7d7da225e..08d89d6db788c8 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -54,12 +54,22 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
 }
 
 void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
-    CallOpInterface call, AbstractDenseLattice *after) {
+    CallOpInterface call, const AbstractDenseLattice &before,
+    AbstractDenseLattice *after) {
+  // Allow for customizing the behavior of calls to external symbols, including
+  // when the analysis is explicitly marked as non-interprocedural.
+  auto callable =
+      dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+  if (!getSolverConfig().isInterprocedural() ||
+      (callable && !callable.getCallableRegion())) {
+    return visitCallControlFlowTransfer(
+        call, CallControlFlowAction::ExternalCallee, before, after);
+  }
 
   const auto *predecessors =
       getOrCreateFor<PredecessorState>(call.getOperation(), call);
-  // If not all return sites are known, then conservatively assume we can't
-  // reason about the data-flow.
+  // Otherwise, if not all return sites are known, then conservatively assume we
+  // can't reason about the data-flow.
   if (!predecessors->allPredecessorsKnown())
     return setToEntryState(after);
 
@@ -108,7 +118,7 @@ void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
   // If this is a call operation, then join its lattices across known return
   // sites.
   if (auto call = dyn_cast<CallOpInterface>(op))
-    return visitCallOperation(call, after);
+    return visitCallOperation(call, *before, after);
 
   // Invoke the operation transfer function.
   visitOperationImpl(op, *before, after);
@@ -130,8 +140,10 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
     if (callable && callable.getCallableRegion() == block->getParent()) {
       const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
       // If not all callsites are known, conservatively mark all lattices as
-      // having reached their pessimistic fixpoints.
-      if (!callsites->allPredecessorsKnown())
+      // having reached their pessimistic fixpoints. Do the same if
+      // interprocedural analysis is not enabled.
+      if (!callsites->allPredecessorsKnown() ||
+          !getSolverConfig().isInterprocedural())
         return setToEntryState(after);
       for (Operation *callsite : callsites->getKnownPredecessors()) {
         // Get the dense lattice before the callsite.
@@ -267,18 +279,20 @@ LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
 }
 
 void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
-    CallOpInterface call, AbstractDenseLattice *before) {
+    CallOpInterface call, const AbstractDenseLattice &after,
+    AbstractDenseLattice *before) {
   // Find the callee.
   Operation *callee = call.resolveCallable(&symbolTable);
   auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
   if (!callable)
     return setToExitState(before);
 
-  // No region means the callee is only declared in this module and we shouldn't
-  // assume anything about it.
+  // No region means the callee is only declared in this module.
   Region *region = callable.getCallableRegion();
-  if (!region || region->empty())
-    return setToExitState(before);
+  if (!region || region->empty() || !getSolverConfig().isInterprocedural()) {
+    return visitCallControlFlowTransfer(
+        call, CallControlFlowAction::ExternalCallee, after, before);
+  }
 
   // Call-level control flow specifies the data flow here.
   //
@@ -324,7 +338,7 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
     return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
                                       before);
   if (auto call = dyn_cast<CallOpInterface>(op))
-    return visitCallOperation(call, before);
+    return visitCallOperation(call, *after, before);
 
   // Invoke the operation transfer function.
   visitOperationImpl(op, *after, before);
@@ -359,8 +373,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
       const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
       // If not all call sites are known, conservative mark all lattices as
       // having reached their pessimistic fix points.
-      if (!callsites->allPredecessorsKnown())
+      if (!callsites->allPredecessorsKnown() ||
+          !getSolverConfig().isInterprocedural()) {
         return setToExitState(before);
+      }
 
       for (Operation *callsite : callsites->getKnownPredecessors()) {
         const AbstractDenseLattice *after;
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 9f544d656df925..b47bba16fd9024 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -116,8 +116,27 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
                                  resultLattices);
   }
 
-  // The results of a call operation are determined by the callgraph.
+  // Grab the lattice elements of the operands.
+  SmallVector<const AbstractSparseLattice *> operandLattices;
+  operandLattices.reserve(op->getNumOperands());
+  for (Value operand : op->getOperands()) {
+    AbstractSparseLattice *operandLattice = getLatticeElement(operand);
+    operandLattice->useDefSubscribe(this);
+    operandLattices.push_back(operandLattice);
+  }
+
   if (auto call = dyn_cast<CallOpInterface>(op)) {
+    // If the call operation is to an external function, attempt to infer the
+    // results from the call arguments.
+    auto callable =
+        dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+    if (!getSolverConfig().isInterprocedural() ||
+        (callable && !callable.getCallableRegion())) {
+      return visitExternalCallImpl(call, operandLattices, resultLattices);
+    }
+
+    // Otherwise, the results of a call operation are determined by the
+    // callgraph.
     const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
     // If not all return sites are known, then conservatively assume we can't
     // reason about the data-flow.
@@ -129,15 +148,6 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
     return;
   }
 
-  // Grab the lattice elements of the operands.
-  SmallVector<const AbstractSparseLattice *> operandLattices;
-  operandLattices.reserve(op->getNumOperands());
-  for (Value operand : op->getOperands()) {
-    AbstractSparseLattice *operandLattice = getLatticeElement(operand);
-    operandLattice->useDefSubscribe(this);
-    operandLattices.push_back(operandLattice);
-  }
-
   // Invoke the operation transfer function.
   visitOperationImpl(op, operandLattices, resultLattices);
 }
@@ -168,8 +178,10 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
       const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
       // If not all callsites are known, conservatively mark all lattices as
       // having reached their pessimistic fixpoints.
-      if (!callsites->allPredecessorsKnown())
+      if (!callsites->allPredecessorsKnown() ||
+          !getSolverConfig().isInterprocedural()) {
         return setAllToEntryStates(argLattices);
+      }
       for (Operation *callsite : callsites->getKnownPredecessors()) {
         auto call = cast<CallOpInterface>(callsite);
         for (auto it : ll...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/75583


More information about the Mlir-commits mailing list