[Mlir-commits] [mlir] [MLIR][LivenessAnalysis] Add custom visitCallOperation (PR #160984)
xin liu
llvmlistbot at llvm.org
Fri Sep 26 22:58:15 PDT 2025
https://github.com/navyxliu created https://github.com/llvm/llvm-project/pull/160984
This diff adds a virtual function, visitCallOperation, to SparseBackwardDataFlowAnalysis. This allows LivenessAnalysis to hook in custom logic for calls to public functions.
>From 7673159670e89527d01e099dd166466477e5c920 Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Fri, 26 Sep 2025 21:09:13 -0700
Subject: [PATCH 1/2] [MLIR][SparseAnalysis] Add virtual function callOperation
to AbstractSparseBackwardDataFlowAnalysis
---
.../mlir/Analysis/DataFlow/SparseAnalysis.h | 15 ++++
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 87 +++++++++++--------
2 files changed, 66 insertions(+), 36 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 3f8874d02afad..eb7fc8e698743 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -464,6 +464,21 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
visitCallableOperation(Operation *op, CallableOpInterface callable,
ArrayRef<AbstractSparseLattice *> operandLattices);
+ /// Visits a call operation. Given the result lattices, set the operand lattices.
+ // Performs interprocedural data flow as follows: if the call
+ /// operation targets an external function, or if the solver is not
+ /// interprocedural, attempts to infer the results from the call arguments
+ /// using the user-provided `visitExternalCallImpl`. Otherwise, computes the
+ /// result lattices from the return sites if all return sites are known;
+ /// otherwise, conservatively marks the result lattices as having reached
+ /// their pessimistic fixpoints.
+ /// This method can be overridden to, for example, be less conservative and
+ /// propagate the information even if some return sites are unknown.
+ virtual LogicalResult
+ visitCallOperation(CallOpInterface call,
+ ArrayRef<AbstractSparseLattice *> operandLattices,
+ ArrayRef<const AbstractSparseLattice *> resultLattices);
+
private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0d2e2ed85549d..622fd8ecfb138 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -495,42 +495,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// For function calls, connect the arguments of the entry blocks to the
// operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
- LDBG() << "Processing CallOpInterface operation";
- Operation *callableOp = call.resolveCallableInTable(&symbolTable);
- if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
- // Not all operands of a call op forward to arguments. Such operands are
- // stored in `unaccounted`.
- BitVector unaccounted(op->getNumOperands(), true);
-
- // If the call invokes an external function (or a function treated as
- // external due to config), defer to the corresponding extension hook.
- // By default, it just does `visitCallOperand` for all operands.
- OperandRange argOperands = call.getArgOperands();
- MutableArrayRef<OpOperand> argOpOperands =
- operandsToOpOperands(argOperands);
- Region *region = callable.getCallableRegion();
- if (!region || region->empty() ||
- !getSolverConfig().isInterprocedural()) {
- visitExternalCallImpl(call, operandLattices, resultLattices);
- return success();
- }
-
- // Otherwise, propagate information from the entry point of the function
- // back to operands whenever possible.
- Block &block = region->front();
- for (auto [blockArg, argOpOperand] :
- llvm::zip(block.getArguments(), argOpOperands)) {
- meet(getLatticeElement(argOpOperand.get()),
- *getLatticeElementFor(getProgramPointAfter(op), blockArg));
- unaccounted.reset(argOpOperand.getOperandNumber());
- }
-
- // Handle the operands of the call op that aren't forwarded to any
- // arguments.
- for (int index : unaccounted.set_bits()) {
- OpOperand &opOperand = op->getOpOperand(index);
- visitCallOperand(opOperand);
- }
+ if (visitCallOperation(call, operandLattices, resultLattices).succeeded()) {
return success();
}
}
@@ -588,6 +553,56 @@ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
return success();
}
+LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallOperation(
+ CallOpInterface call,
+ ArrayRef<AbstractSparseLattice *> operandLattices,
+ ArrayRef<const AbstractSparseLattice *> resultLattices) {
+ LDBG() << "Processing CallOpInterface operation";
+
+ // Treat any function as external due to config.
+ if (!getSolverConfig().isInterprocedural()) {
+ visitExternalCallImpl(call, operandLattices, resultLattices);
+ return success();
+ }
+
+ Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+ if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
+ // Not all operands of a call op forward to arguments. Such operands are
+ // stored in `unaccounted`.
+ BitVector unaccounted(call->getNumOperands(), true);
+
+ // If the call invokes an external function, defer to the corresponding extension hook.
+ // By default, it just does `visitCallOperand` for all operands.
+ OperandRange argOperands = call.getArgOperands();
+ MutableArrayRef<OpOperand> argOpOperands =
+ operandsToOpOperands(argOperands);
+ Region *region = callable.getCallableRegion();
+
+ if (!region || region->empty()) {
+ visitExternalCallImpl(call, operandLattices, resultLattices);
+ return success();
+ }
+ // Otherwise, propagate information from the entry point of the function
+ // back to operands whenever possible.
+ Block &block = region->front();
+ for (auto [blockArg, argOpOperand] :
+ llvm::zip(block.getArguments(), argOpOperands)) {
+ meet(getLatticeElement(argOpOperand.get()),
+ *getLatticeElementFor(getProgramPointAfter(call), blockArg));
+ unaccounted.reset(argOpOperand.getOperandNumber());
+ }
+
+ // Handle the operands of the call op that aren't forwarded to any
+ // arguments.
+ for (int index : unaccounted.set_bits()) {
+ OpOperand &opOperand = call->getOpOperand(index);
+ visitCallOperand(opOperand);
+ }
+ return success();
+ }
+ return failure();
+}
+
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
RegionBranchOpInterface branch,
ArrayRef<AbstractSparseLattice *> operandLattices) {
>From b14d2d585d077594a41210c2eb7824b50cbd299e Mon Sep 17 00:00:00 2001
From: Xin Liu <xxinliu at meta.com>
Date: Fri, 26 Sep 2025 22:28:29 -0700
Subject: [PATCH 2/2] Define custom visitCallOperation for liveness Analysis.
---
.../mlir/Analysis/DataFlow/LivenessAnalysis.h | 5 +
.../mlir/Analysis/DataFlow/SparseAnalysis.h | 39 +++++---
.../Analysis/DataFlow/LivenessAnalysis.cpp | 22 +++++
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 95 +++++++++----------
4 files changed, 95 insertions(+), 66 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
index cf1fd6e2d48ca..e13822f064999 100644
--- a/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/LivenessAnalysis.h
@@ -82,6 +82,11 @@ class LivenessAnalysis : public SparseBackwardDataFlowAnalysis<Liveness> {
LogicalResult visitOperation(Operation *op, ArrayRef<Liveness *> operands,
ArrayRef<const Liveness *> results) override;
+ LogicalResult
+ visitCallOperation(CallOpInterface call,
+ ArrayRef<Liveness *> operandLattices,
+ ArrayRef<const Liveness *> resultLattices) override;
+
void visitBranchOperand(OpOperand &operand) override;
void visitCallOperand(OpOperand &operand) override;
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index eb7fc8e698743..53aa4d03f1c41 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -431,6 +431,11 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
+ /// The transfer function for call operations.
+ virtual LogicalResult visitCallOperationImpl(
+ CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
+ ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
+
// Visit operands on branch instructions that are not forwarded.
virtual void visitBranchOperand(OpOperand &operand) = 0;
@@ -464,21 +469,6 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
visitCallableOperation(Operation *op, CallableOpInterface callable,
ArrayRef<AbstractSparseLattice *> operandLattices);
- /// Visits a call operation. Given the result lattices, set the operand lattices.
- // Performs interprocedural data flow as follows: if the call
- /// operation targets an external function, or if the solver is not
- /// interprocedural, attempts to infer the results from the call arguments
- /// using the user-provided `visitExternalCallImpl`. Otherwise, computes the
- /// result lattices from the return sites if all return sites are known;
- /// otherwise, conservatively marks the result lattices as having reached
- /// their pessimistic fixpoints.
- /// This method can be overridden to, for example, be less conservative and
- /// propagate the information even if some return sites are unknown.
- virtual LogicalResult
- visitCallOperation(CallOpInterface call,
- ArrayRef<AbstractSparseLattice *> operandLattices,
- ArrayRef<const AbstractSparseLattice *> resultLattices);
-
private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);
@@ -515,6 +505,7 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
SmallVector<const AbstractSparseLattice *>
getLatticeElementsFor(ProgramPoint *point, ValueRange values);
+protected:
SymbolTableCollection &symbolTable;
};
@@ -544,6 +535,13 @@ class SparseBackwardDataFlowAnalysis
ArrayRef<StateT *> operands,
ArrayRef<const StateT *> results) = 0;
+ // implementation-specific hook for visitCallOperation.
+ virtual LogicalResult visitCallOperation(CallOpInterface call,
+ ArrayRef<StateT *> operands,
+ ArrayRef<const StateT *> results) {
+ return failure();
+ }
+
/// Visit a call to an external function. This function is expected to set
/// lattice values of the call operands. By default, calls `visitCallOperand`
/// for all operands.
@@ -598,6 +596,17 @@ class SparseBackwardDataFlowAnalysis
{reinterpret_cast<const StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
+ LogicalResult visitCallOperationImpl(
+ CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
+ ArrayRef<const AbstractSparseLattice *> resultLattices) override {
+ return visitCallOperation(
+ call,
+ {reinterpret_cast<StateT *const *>(operandLattices.begin()),
+ operandLattices.size()},
+ {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
+ resultLattices.size()});
+ }
+
};
} // end namespace dataflow
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index fdb97d5963299..6b1809bc8f0d3 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include <cassert>
#include <mlir/Analysis/DataFlow/LivenessAnalysis.h>
@@ -115,6 +116,27 @@ LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands,
return success();
}
+LogicalResult
+LivenessAnalysis::visitCallOperation(CallOpInterface call, ArrayRef<Liveness *> operands,
+ ArrayRef<const Liveness *> results) {
+ LDBG() << "[visitCallOperation] inspecting " << call;
+ // For thread-safety, check config first before accessing symbolTable.
+ if (!getSolverConfig().isInterprocedural()) {
+ visitExternalCall(call, operands, results);
+ return success();
+ }
+ Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(callableOp)) {
+ if (funcOp.isPublic()) {
+ LDBG() << "[visitCallOperation] encounter a public function " << funcOp
+ << " Treat it as external.";
+ visitExternalCall(call, operands, results);
+ return success();
+ }
+ }
+ return failure();
+}
+
void LivenessAnalysis::visitBranchOperand(OpOperand &operand) {
LDBG() << "Visiting branch operand: " << operand.get()
<< " in op: " << *operand.getOwner();
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 622fd8ecfb138..c7169bf977f9b 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -495,7 +495,50 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// For function calls, connect the arguments of the entry blocks to the
// operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
- if (visitCallOperation(call, operandLattices, resultLattices).succeeded()) {
+ LDBG() << "Processing CallOpInterface operation";
+
+ if (visitCallOperationImpl(call, operandLattices, resultLattices).succeeded()) {
+ return success();
+ }
+
+ // Treat any function as external due to config.
+ if (!getSolverConfig().isInterprocedural()) {
+ visitExternalCallImpl(call, operandLattices, resultLattices);
+ return success();
+ }
+
+ Operation *callableOp = call.resolveCallableInTable(&symbolTable);
+ if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
+ // Not all operands of a call op forward to arguments. Such operands are
+ // stored in `unaccounted`.
+ BitVector unaccounted(op->getNumOperands(), true);
+
+ // If the call invokes an external function, defer to the corresponding extension hook.
+ // By default, it just does `visitCallOperand` for all operands.
+ OperandRange argOperands = call.getArgOperands();
+ MutableArrayRef<OpOperand> argOpOperands =
+ operandsToOpOperands(argOperands);
+ Region *region = callable.getCallableRegion();
+ if (!region || region->empty()) {
+ visitExternalCallImpl(call, operandLattices, resultLattices);
+ return success();
+ }
+ // Otherwise, propagate information from the entry point of the function
+ // back to operands whenever possible.
+ Block &block = region->front();
+ for (auto [blockArg, argOpOperand] :
+ llvm::zip(block.getArguments(), argOpOperands)) {
+ meet(getLatticeElement(argOpOperand.get()),
+ *getLatticeElementFor(getProgramPointAfter(op), blockArg));
+ unaccounted.reset(argOpOperand.getOperandNumber());
+ }
+
+ // Handle the operands of the call op that aren't forwarded to any
+ // arguments.
+ for (int index : unaccounted.set_bits()) {
+ OpOperand &opOperand = op->getOpOperand(index);
+ visitCallOperand(opOperand);
+ }
return success();
}
}
@@ -553,56 +596,6 @@ LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
return success();
}
-LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallOperation(
- CallOpInterface call,
- ArrayRef<AbstractSparseLattice *> operandLattices,
- ArrayRef<const AbstractSparseLattice *> resultLattices) {
- LDBG() << "Processing CallOpInterface operation";
-
- // Treat any function as external due to config.
- if (!getSolverConfig().isInterprocedural()) {
- visitExternalCallImpl(call, operandLattices, resultLattices);
- return success();
- }
-
- Operation *callableOp = call.resolveCallableInTable(&symbolTable);
- if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
- // Not all operands of a call op forward to arguments. Such operands are
- // stored in `unaccounted`.
- BitVector unaccounted(call->getNumOperands(), true);
-
- // If the call invokes an external function, defer to the corresponding extension hook.
- // By default, it just does `visitCallOperand` for all operands.
- OperandRange argOperands = call.getArgOperands();
- MutableArrayRef<OpOperand> argOpOperands =
- operandsToOpOperands(argOperands);
- Region *region = callable.getCallableRegion();
-
- if (!region || region->empty()) {
- visitExternalCallImpl(call, operandLattices, resultLattices);
- return success();
- }
- // Otherwise, propagate information from the entry point of the function
- // back to operands whenever possible.
- Block &block = region->front();
- for (auto [blockArg, argOpOperand] :
- llvm::zip(block.getArguments(), argOpOperands)) {
- meet(getLatticeElement(argOpOperand.get()),
- *getLatticeElementFor(getProgramPointAfter(call), blockArg));
- unaccounted.reset(argOpOperand.getOperandNumber());
- }
-
- // Handle the operands of the call op that aren't forwarded to any
- // arguments.
- for (int index : unaccounted.set_bits()) {
- OpOperand &opOperand = call->getOpOperand(index);
- visitCallOperand(opOperand);
- }
- return success();
- }
- return failure();
-}
-
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
RegionBranchOpInterface branch,
ArrayRef<AbstractSparseLattice *> operandLattices) {
More information about the Mlir-commits
mailing list