[Mlir-commits] [mlir] [MLIR][LivenessAnalysis] Add custom visitCallOperation (PR #160984)

xin liu llvmlistbot at llvm.org
Fri Sep 26 22:58:15 PDT 2025


https://github.com/navyxliu created https://github.com/llvm/llvm-project/pull/160984

This diff adds a virtual function, visitCallOperation, to SparseBackwardDataFlowAnalysis. This allows LivenessAnalysis to hook in custom logic for calls to public functions.

>From 7673159670e89527d01e099dd166466477e5c920 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Fri, 26 Sep 2025 21:09:13 -0700
Subject: [PATCH 1/2] [MLIR][SparseAnalysis] Add virtual function
 visitCallOperation to AbstractSparseBackwardDataFlowAnalysis

---
 .../mlir/Analysis/DataFlow/SparseAnalysis.h   | 15 ++++
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 87 +++++++++++--------
 2 files changed, 66 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 3f8874d02afad..eb7fc8e698743 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -464,6 +464,21 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   visitCallableOperation(Operation *op, CallableOpInterface callable,
                          ArrayRef<AbstractSparseLattice *> operandLattices);
 
+  /// Visits a call operation. Given the result lattices, set the operand lattices.
+  //  Performs interprocedural data flow as follows: if the call
+  /// operation targets an external function, or if the solver is not
+  /// interprocedural, attempts to infer the results from the call arguments
+  /// using the user-provided `visitExternalCallImpl`. Otherwise, computes the
+  /// result lattices from the return sites if all return sites are known;
+  /// otherwise, conservatively marks the result lattices as having reached
+  /// their pessimistic fixpoints.
+  /// This method can be overridden to, for example, be less conservative and
+  /// propagate the information even if some return sites are unknown.
+  virtual LogicalResult
+  visitCallOperation(CallOpInterface call,
+                     ArrayRef<AbstractSparseLattice *> operandLattices,
+                     ArrayRef<const AbstractSparseLattice *> resultLattices);
+
 private:
   /// Recursively initialize the analysis on nested operations and blocks.
   LogicalResult initializeRecursively(Operation *op);
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0d2e2ed85549d..622fd8ecfb138 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -495,42 +495,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // For function calls, connect the arguments of the entry blocks to the
   // operands of the call op that are forwarded to these arguments.
   if (auto call = dyn_cast<CallOpInterface>(op)) {
-    LDBG() << "Processing CallOpInterface operation";
-    Operation *callableOp = call.resolveCallableInTable(&symbolTable);
-    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
-      // Not all operands of a call op forward to arguments. Such operands are
-      // stored in `unaccounted`.
-      BitVector unaccounted(op->getNumOperands(), true);
-
-      // If the call invokes an external function (or a function treated as
-      // external due to config), defer to the corresponding extension hook.
-      // By default, it just does `visitCallOperand` for all operands.
-      OperandRange argOperands = call.getArgOperands();
-      MutableArrayRef<OpOperand> argOpOperands =
-          operandsToOpOperands(argOperands);
-      Region *region = callable.getCallableRegion();
-      if (!region || region->empty() ||
-          !getSolverConfig().isInterprocedural()) {
-        visitExternalCallImpl(call, operandLattices, resultLattices);
-        return success();
-      }
-
-      // Otherwise, propagate information from the entry point of the function
-      // back to operands whenever possible.
-      Block &block = region->front();
-      for (auto [blockArg, argOpOperand] :
-           llvm::zip(block.getArguments(), argOpOperands)) {
-        meet(getLatticeElement(argOpOperand.get()),
-             *getLatticeElementFor(getProgramPointAfter(op), blockArg));
-        unaccounted.reset(argOpOperand.getOperandNumber());
-      }
-
-      // Handle the operands of the call op that aren't forwarded to any
-      // arguments.
-      for (int index : unaccounted.set_bits()) {
-        OpOperand &opOperand = op->getOpOperand(index);
-        visitCallOperand(opOperand);
-      }
+    if (visitCallOperation(call, operandLattices, resultLattices).succeeded()) {
       return success();
     }
   }
@@ -588,6 +553,56 @@ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
   return success();
 }
 
+LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallOperation(
+    CallOpInterface call,
+    ArrayRef<AbstractSparseLattice *> operandLattices,
+    ArrayRef<const AbstractSparseLattice *> resultLattices) {
+    LDBG() << "Processing CallOpInterface operation";
+
+    //  Treat any function as external due to config.
+    if (!getSolverConfig().isInterprocedural()) {
+        visitExternalCallImpl(call, operandLattices, resultLattices);
+        return success();
+    }
+
+    Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
+      // Not all operands of a call op forward to arguments. Such operands are
+      // stored in `unaccounted`.
+      BitVector unaccounted(call->getNumOperands(), true);
+
+      // If the call invokes an external function, defer to the corresponding extension hook.
+      // By default, it just does `visitCallOperand` for all operands.
+      OperandRange argOperands = call.getArgOperands();
+      MutableArrayRef<OpOperand> argOpOperands =
+          operandsToOpOperands(argOperands);
+      Region *region = callable.getCallableRegion();
+
+      if (!region || region->empty()) {
+        visitExternalCallImpl(call, operandLattices, resultLattices);
+        return success();
+      }
+      // Otherwise, propagate information from the entry point of the function
+      // back to operands whenever possible.
+      Block &block = region->front();
+      for (auto [blockArg, argOpOperand] :
+           llvm::zip(block.getArguments(), argOpOperands)) {
+        meet(getLatticeElement(argOpOperand.get()),
+             *getLatticeElementFor(getProgramPointAfter(call), blockArg));
+        unaccounted.reset(argOpOperand.getOperandNumber());
+      }
+
+      // Handle the operands of the call op that aren't forwarded to any
+      // arguments.
+      for (int index : unaccounted.set_bits()) {
+        OpOperand &opOperand = call->getOpOperand(index);
+        visitCallOperand(opOperand);
+      }
+      return success();
+    }
+    return failure();
+}
+
 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
     RegionBranchOpInterface branch,
     ArrayRef<AbstractSparseLattice *> operandLattices) {

>From b14d2d585d077594a41210c2eb7824b50cbd299e Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Fri, 26 Sep 2025 22:28:29 -0700
Subject: [PATCH 2/2] Define a custom visitCallOperation for LivenessAnalysis.

---
 .../mlir/Analysis/DataFlow/LivenessAnalysis.h |  5 +
 .../mlir/Analysis/DataFlow/SparseAnalysis.h   | 39 +++++---
 .../Analysis/DataFlow/LivenessAnalysis.cpp    | 22 +++++
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 95 +++++++++----------
 4 files changed, 95 insertions(+), 66 deletions(-)

diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..e13822f064999 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -82,6 +82,11 @@ class LivenessAnalysis : public SparseBackwardDataFlowAnalysis<Liveness> {
   LogicalResult visitOperation(Operation *op, ArrayRef<Liveness *> operands,
                                ArrayRef<const Liveness *> results) override;
 
+  LogicalResult
+  visitCallOperation(CallOpInterface call,
+                     ArrayRef<Liveness *> operandLattices,
+                     ArrayRef<const Liveness *> resultLattices) override;
+
   void visitBranchOperand(OpOperand &operand) override;
 
   void visitCallOperand(OpOperand &operand) override;
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index eb7fc8e698743..53aa4d03f1c41 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -431,6 +431,11 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
       CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
       ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
 
+  /// Custom call handling; return failure() to fall back to the default.
+  virtual LogicalResult visitCallOperationImpl(
+      CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
+      ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
+
   // Visit operands on branch instructions that are not forwarded.
   virtual void visitBranchOperand(OpOperand &operand) = 0;
 
@@ -464,21 +469,6 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   visitCallableOperation(Operation *op, CallableOpInterface callable,
                          ArrayRef<AbstractSparseLattice *> operandLattices);
 
-  /// Visits a call operation. Given the result lattices, set the operand lattices.
-  //  Performs interprocedural data flow as follows: if the call
-  /// operation targets an external function, or if the solver is not
-  /// interprocedural, attempts to infer the results from the call arguments
-  /// using the user-provided `visitExternalCallImpl`. Otherwise, computes the
-  /// result lattices from the return sites if all return sites are known;
-  /// otherwise, conservatively marks the result lattices as having reached
-  /// their pessimistic fixpoints.
-  /// This method can be overridden to, for example, be less conservative and
-  /// propagate the information even if some return sites are unknown.
-  virtual LogicalResult
-  visitCallOperation(CallOpInterface call,
-                     ArrayRef<AbstractSparseLattice *> operandLattices,
-                     ArrayRef<const AbstractSparseLattice *> resultLattices);
-
 private:
   /// Recursively initialize the analysis on nested operations and blocks.
   LogicalResult initializeRecursively(Operation *op);
@@ -515,6 +505,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
   SmallVector<const AbstractSparseLattice *>
   getLatticeElementsFor(ProgramPoint *point, ValueRange values);
 
+protected:
   SymbolTableCollection &symbolTable;
 };
 
@@ -544,6 +535,13 @@ class SparseBackwardDataFlowAnalysis
                                        ArrayRef<StateT *> operands,
                                        ArrayRef<const StateT *> results) = 0;
 
+  /// Hook for custom call handling; returning failure() uses the default.
+  virtual LogicalResult visitCallOperation(CallOpInterface call,
+                                           ArrayRef<StateT *> operands,
+                                           ArrayRef<const StateT *> results) {
+    return failure();
+  }
+
   /// Visit a call to an external function. This function is expected to set
   /// lattice values of the call operands. By default, calls `visitCallOperand`
   /// for all operands.
@@ -598,6 +596,17 @@ class SparseBackwardDataFlowAnalysis
         {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
          resultLattices.size()});
   }
+
+  /// Casts the lattices to `StateT` and forwards to `visitCallOperation`.
+  LogicalResult visitCallOperationImpl(
+      CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
+      ArrayRef<const AbstractSparseLattice *> resultLattices) override {
+    return visitCallOperation(
+        call, {reinterpret_cast<StateT *const *>(operandLattices.begin()),
+               operandLattices.size()},
+        {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
+         resultLattices.size()});
+  }
 };
 
 } // end namespace dataflow
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index fdb97d5963299..6b1809bc8f0d3 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include <cassert>
 #include <mlir/Analysis/DataFlow/LivenessAnalysis.h>
 
@@ -115,6 +116,27 @@ LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands,
   return success();
 }
 
+/// Handles a call as external when the solver is not interprocedural or the
+/// callee is a public function; otherwise returns failure() (default path).
+LogicalResult LivenessAnalysis::visitCallOperation(
+    CallOpInterface call, ArrayRef<Liveness *> operands,
+    ArrayRef<const Liveness *> results) {
+  LDBG() << "[visitCallOperation] inspecting " << call;
+  if (!getSolverConfig().isInterprocedural()) {
+    visitExternalCall(call, operands, results);
+    return success();
+  }
+  // The callee may not resolve (e.g. indirect calls), so null is possible.
+  Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+  auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callableOp);
+  if (!funcOp || !funcOp.isPublic())
+    return failure();
+  LDBG() << "[visitCallOperation] encounter a public function " << funcOp
+         << " Treat it as external.";
+  visitExternalCall(call, operands, results);
+  return success();
+}
+
 void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
   LDBG() << "Visiting branch operand: " << operand.get()
          << " in op: " << *operand.getOwner();
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 622fd8ecfb138..c7169bf977f9b 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -495,7 +495,50 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
   // For function calls, connect the arguments of the entry blocks to the
   // operands of the call op that are forwarded to these arguments.
   if (auto call = dyn_cast<CallOpInterface>(op)) {
-    if (visitCallOperation(call, operandLattices, resultLattices).succeeded()) {
+    LDBG() << "Processing CallOpInterface operation";
+
+    if (visitCallOperationImpl(call, operandLattices, resultLattices).succeeded()) {
+      return success();
+    }
+
+    //  Treat any function as external due to config.
+    if (!getSolverConfig().isInterprocedural()) {
+        visitExternalCallImpl(call, operandLattices, resultLattices);
+        return success();
+    }
+
+    Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
+      // Not all operands of a call op forward to arguments. Such operands are
+      // stored in `unaccounted`.
+      BitVector unaccounted(op->getNumOperands(), true);
+
+      // If the call invokes an external function, defer to the corresponding extension hook.
+      // By default, it just does `visitCallOperand` for all operands.
+      OperandRange argOperands = call.getArgOperands();
+      MutableArrayRef<OpOperand> argOpOperands =
+          operandsToOpOperands(argOperands);
+      Region *region = callable.getCallableRegion();
+      if (!region || region->empty()) {
+        visitExternalCallImpl(call, operandLattices, resultLattices);
+        return success();
+      }
+      // Otherwise, propagate information from the entry point of the function
+      // back to operands whenever possible.
+      Block &block = region->front();
+      for (auto [blockArg, argOpOperand] :
+           llvm::zip(block.getArguments(), argOpOperands)) {
+        meet(getLatticeElement(argOpOperand.get()),
+             *getLatticeElementFor(getProgramPointAfter(op), blockArg));
+        unaccounted.reset(argOpOperand.getOperandNumber());
+      }
+
+      // Handle the operands of the call op that aren't forwarded to any
+      // arguments.
+      for (int index : unaccounted.set_bits()) {
+        OpOperand &opOperand = op->getOpOperand(index);
+        visitCallOperand(opOperand);
+      }
       return success();
     }
   }
@@ -553,56 +596,6 @@ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
   return success();
 }
 
-LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallOperation(
-    CallOpInterface call,
-    ArrayRef<AbstractSparseLattice *> operandLattices,
-    ArrayRef<const AbstractSparseLattice *> resultLattices) {
-    LDBG() << "Processing CallOpInterface operation";
-
-    //  Treat any function as external due to config.
-    if (!getSolverConfig().isInterprocedural()) {
-        visitExternalCallImpl(call, operandLattices, resultLattices);
-        return success();
-    }
-
-    Operation *callableOp = call.resolveCallableInTable(&symbolTable);
-    if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
-      // Not all operands of a call op forward to arguments. Such operands are
-      // stored in `unaccounted`.
-      BitVector unaccounted(call->getNumOperands(), true);
-
-      // If the call invokes an external function, defer to the corresponding extension hook.
-      // By default, it just does `visitCallOperand` for all operands.
-      OperandRange argOperands = call.getArgOperands();
-      MutableArrayRef<OpOperand> argOpOperands =
-          operandsToOpOperands(argOperands);
-      Region *region = callable.getCallableRegion();
-
-      if (!region || region->empty()) {
-        visitExternalCallImpl(call, operandLattices, resultLattices);
-        return success();
-      }
-      // Otherwise, propagate information from the entry point of the function
-      // back to operands whenever possible.
-      Block &block = region->front();
-      for (auto [blockArg, argOpOperand] :
-           llvm::zip(block.getArguments(), argOpOperands)) {
-        meet(getLatticeElement(argOpOperand.get()),
-             *getLatticeElementFor(getProgramPointAfter(call), blockArg));
-        unaccounted.reset(argOpOperand.getOperandNumber());
-      }
-
-      // Handle the operands of the call op that aren't forwarded to any
-      // arguments.
-      for (int index : unaccounted.set_bits()) {
-        OpOperand &opOperand = call->getOpOperand(index);
-        visitCallOperand(opOperand);
-      }
-      return success();
-    }
-    return failure();
-}
-
 void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
     RegionBranchOpInterface branch,
     ArrayRef<AbstractSparseLattice *> operandLattices) {



More information about the Mlir-commits mailing list