[Mlir-commits] [mlir] 5a531b1 - [mlir] NFC: Add data flow analysis extension points (#142549)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 4 05:15:09 PDT 2025


Author: Vadim Curcă
Date: 2025-06-04T14:15:05+02:00
New Revision: 5a531b115844a038d7bd0108ebafe6bacbef75e3

URL: https://github.com/llvm/llvm-project/commit/5a531b115844a038d7bd0108ebafe6bacbef75e3
DIFF: https://github.com/llvm/llvm-project/commit/5a531b115844a038d7bd0108ebafe6bacbef75e3.diff

LOG: [mlir] NFC: Add data flow analysis extension points (#142549)

This commit introduces `visitCallOperation` and `visitCallableOperation`
extension points in the sparse data flow analysis framework. Overriding
them lets a derived analysis be less conservative (for example, by
propagating information even when not all call or return sites are
known) without duplicating much of the framework's code.

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
    mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 1b2c679176107..3f8874d02afad 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -235,6 +235,30 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
   /// Join the lattice element and propagate and update if it changed.
   void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
 
+  /// Visits a call operation. Given the operand lattices, sets the result
+  /// lattices. Performs interprocedural data flow as follows: if the call
+  /// operation targets an external function, or if the solver is not
+  /// interprocedural, attempts to infer the results from the call arguments
+  /// using the user-provided `visitExternalCallImpl`. Otherwise, computes the
+  /// result lattices from the return sites if all return sites are known;
+  /// otherwise, conservatively marks the result lattices as having reached
+  /// their pessimistic fixpoints.
+  /// This method can be overridden to, for example, be less conservative and
+  /// propagate the information even if some return sites are unknown.
+  virtual LogicalResult
+  visitCallOperation(CallOpInterface call,
+                     ArrayRef<const AbstractSparseLattice *> operandLattices,
+                     ArrayRef<AbstractSparseLattice *> resultLattices);
+
+  /// Visits a callable operation. Computes the argument lattices from call
+  /// sites if all call sites are known; otherwise, conservatively marks them
+  /// as having reached their pessimistic fixpoints.
+  /// This method can be overridden to, for example, be less conservative and
+  /// propagate the information even if some call sites are unknown.
+  virtual void
+  visitCallableOperation(CallableOpInterface callable,
+                         ArrayRef<AbstractSparseLattice *> argLattices);
+
 private:
   /// Recursively initialize the analysis on nested operations and blocks.
   LogicalResult initializeRecursively(Operation *op);
@@ -430,6 +454,16 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   /// Join the lattice element and propagate and update if it changed.
   void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
 
+  /// Visits a callable operation. If all the call sites are known computes the
+  /// operand lattices of `op` from the result lattices of all the call sites;
+  /// otherwise, conservatively marks them as having reached their pessimistic
+  /// fixpoints.
+  /// This method can be overridden to, for example, be less conservative and
+  /// propagate the information even if some call sites are unknown.
+  virtual LogicalResult
+  visitCallableOperation(Operation *op, CallableOpInterface callable,
+                         ArrayRef<AbstractSparseLattice *> operandLattices);
+
 private:
   /// Recursively initialize the analysis on nested operations and blocks.
   LogicalResult initializeRecursively(Operation *op);

diff  --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0b39d14042493..016e59dcb744e 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
     operandLattices.push_back(operandLattice);
   }
 
-  if (auto call = dyn_cast<CallOpInterface>(op)) {
-    // If the call operation is to an external function, attempt to infer the
-    // results from the call arguments.
-    auto callable =
-        dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
-    if (!getSolverConfig().isInterprocedural() ||
-        (callable && !callable.getCallableRegion())) {
-      visitExternalCallImpl(call, operandLattices, resultLattices);
-      return success();
-    }
-
-    // Otherwise, the results of a call operation are determined by the
-    // callgraph.
-    const auto *predecessors = getOrCreateFor<PredecessorState>(
-        getProgramPointAfter(op), getProgramPointAfter(call));
-    // If not all return sites are known, then conservatively assume we can't
-    // reason about the data-flow.
-    if (!predecessors->allPredecessorsKnown()) {
-      setAllToEntryStates(resultLattices);
-      return success();
-    }
-    for (Operation *predecessor : predecessors->getKnownPredecessors())
-      for (auto &&[operand, resLattice] :
-           llvm::zip(predecessor->getOperands(), resultLattices))
-        join(resLattice,
-             *getLatticeElementFor(getProgramPointAfter(op), operand));
-    return success();
-  }
+  if (auto call = dyn_cast<CallOpInterface>(op))
+    return visitCallOperation(call, operandLattices, resultLattices);
 
   // Invoke the operation transfer function.
   return visitOperationImpl(op, operandLattices, resultLattices);
@@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
   if (block->isEntryBlock()) {
     // Check if this block is the entry block of a callable region.
     auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
-    if (callable && callable.getCallableRegion() == block->getParent()) {
-      const auto *callsites = getOrCreateFor<PredecessorState>(
-          getProgramPointBefore(block), getProgramPointAfter(callable));
-      // If not all callsites are known, conservatively mark all lattices as
-      // having reached their pessimistic fixpoints.
-      if (!callsites->allPredecessorsKnown() ||
-          !getSolverConfig().isInterprocedural()) {
-        return setAllToEntryStates(argLattices);
-      }
-      for (Operation *callsite : callsites->getKnownPredecessors()) {
-        auto call = cast<CallOpInterface>(callsite);
-        for (auto it : llvm::zip(call.getArgOperands(), argLattices))
-          join(std::get<1>(it),
-               *getLatticeElementFor(getProgramPointBefore(block),
-                                     std::get<0>(it)));
-      }
-      return;
-    }
+    if (callable && callable.getCallableRegion() == block->getParent())
+      return visitCallableOperation(callable, argLattices);
 
     // Check if the lattices can be determined from region control flow.
     if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
@@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
   }
 }
 
+LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
+    CallOpInterface call,
+    ArrayRef<const AbstractSparseLattice *> operandLattices,
+    ArrayRef<AbstractSparseLattice *> resultLattices) {
+  // If the call operation is to an external function, attempt to infer the
+  // results from the call arguments.
+  auto callable =
+      dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+  if (!getSolverConfig().isInterprocedural() ||
+      (callable && !callable.getCallableRegion())) {
+    visitExternalCallImpl(call, operandLattices, resultLattices);
+    return success();
+  }
+
+  // Otherwise, the results of a call operation are determined by the
+  // callgraph.
+  const auto *predecessors = getOrCreateFor<PredecessorState>(
+      getProgramPointAfter(call), getProgramPointAfter(call));
+  // If not all return sites are known, then conservatively assume we can't
+  // reason about the data-flow.
+  if (!predecessors->allPredecessorsKnown()) {
+    setAllToEntryStates(resultLattices);
+    return success();
+  }
+  for (Operation *predecessor : predecessors->getKnownPredecessors())
+    for (auto &&[operand, resLattice] :
+         llvm::zip(predecessor->getOperands(), resultLattices))
+      join(resLattice,
+           *getLatticeElementFor(getProgramPointAfter(call), operand));
+  return success();
+}
+
+void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
+    CallableOpInterface callable,
+    ArrayRef<AbstractSparseLattice *> argLattices) {
+  Block *entryBlock = &callable.getCallableRegion()->front();
+  const auto *callsites = getOrCreateFor<PredecessorState>(
+      getProgramPointBefore(entryBlock), getProgramPointAfter(callable));
+  // If not all callsites are known, conservatively mark all lattices as
+  // having reached their pessimistic fixpoints.
+  if (!callsites->allPredecessorsKnown() ||
+      !getSolverConfig().isInterprocedural()) {
+    return setAllToEntryStates(argLattices);
+  }
+  for (Operation *callsite : callsites->getKnownPredecessors()) {
+    auto call = cast<CallOpInterface>(callsite);
+    for (auto it : llvm::zip(call.getArgOperands(), argLattices))
+      join(std::get<1>(it),
+           *getLatticeElementFor(getProgramPointBefore(entryBlock),
+                                 std::get<0>(it)));
+  }
+}
+
 void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
     ProgramPoint *point, RegionBranchOpInterface branch,
     RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
@@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   if (op->hasTrait<OpTrait::ReturnLike>()) {
     // Going backwards, the operands of the return are derived from the
     // results of all CallOps calling this CallableOp.
-    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
-      const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
-          getProgramPointAfter(op), getProgramPointAfter(callable));
-      if (callsites->allPredecessorsKnown()) {
-        for (Operation *call : callsites->getKnownPredecessors()) {
-          SmallVector<const AbstractSparseLattice *> callResultLattices =
-              getLatticeElementsFor(getProgramPointAfter(op),
-                                    call->getResults());
-          for (auto [op, result] :
-               llvm::zip(operandLattices, callResultLattices))
-            meet(op, *result);
-        }
-      } else {
-        // If we don't know all the callers, we can't know where the
-        // returned values go. Note that, in particular, this will trigger
-        // for the return ops of any public functions.
-        setAllToExitStates(operandLattices);
-      }
-      return success();
-    }
+    if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+      return visitCallableOperation(op, callable, operandLattices);
   }
 
   return visitOperationImpl(op, operandLattices, resultLattices);
 }
 
+LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
+    Operation *op, CallableOpInterface callable,
+    ArrayRef<AbstractSparseLattice *> operandLattices) {
+  const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
+      getProgramPointAfter(op), getProgramPointAfter(callable));
+  if (callsites->allPredecessorsKnown()) {
+    for (Operation *call : callsites->getKnownPredecessors()) {
+      SmallVector<const AbstractSparseLattice *> callResultLattices =
+          getLatticeElementsFor(getProgramPointAfter(op), call->getResults());
+      for (auto [op, result] : llvm::zip(operandLattices, callResultLattices))
+        meet(op, *result);
+    }
+  } else {
+    // If we don't know all the callers, we can't know where the
+    // returned values go. Note that, in particular, this will trigger
+    // for the return ops of any public functions.
+    setAllToExitStates(operandLattices);
+  }
+  return success();
+}
+
 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
     RegionBranchOpInterface branch,
     ArrayRef<AbstractSparseLattice *> operandLattices) {


        


More information about the Mlir-commits mailing list