[Mlir-commits] [mlir] 5a531b1 - [mlir] NFC: Add data flow analysis extension points (#142549)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 4 05:15:09 PDT 2025
Author: Vadim Curcă
Date: 2025-06-04T14:15:05+02:00
New Revision: 5a531b115844a038d7bd0108ebafe6bacbef75e3
URL: https://github.com/llvm/llvm-project/commit/5a531b115844a038d7bd0108ebafe6bacbef75e3
DIFF: https://github.com/llvm/llvm-project/commit/5a531b115844a038d7bd0108ebafe6bacbef75e3.diff
LOG: [mlir] NFC: Add data flow analysis extension points (#142549)
This commit introduces `visitCallOperation` and `visitCallableOperation`
extension points in the sparse data flow analysis framework. This
allows, for example, making the analysis less conservative without
much code duplication, so that information is propagated even if not
all of the call or return sites are known.
Added:
Modified:
mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 1b2c679176107..3f8874d02afad 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -235,6 +235,30 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// Join the lattice element and propagate and update if it changed.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+ /// Visits a call operation. Given the operand lattices, sets the result
+ /// lattices. If the solver is not interprocedural, or the call resolves to a
+ /// callable without a callable region (an external function), attempts to
+ /// infer the results from the call arguments using the user-provided
+ /// `visitExternalCallImpl`. Otherwise, if all return sites are known, joins
+ /// their operands into the result lattices; if not, conservatively marks
+ /// the result lattices as having reached their pessimistic fixpoints.
+ /// The default implementation always returns success.
+ /// This method can be overridden to, for example, be less conservative and
+ /// propagate the information even if some return sites are unknown.
+ virtual LogicalResult
+ visitCallOperation(CallOpInterface call,
+ ArrayRef<const AbstractSparseLattice *> operandLattices,
+ ArrayRef<AbstractSparseLattice *> resultLattices);
+
+ /// Visits a callable operation. If the solver is interprocedural and all call
+ /// sites are known, joins the call-site arguments into `argLattices`;
+ /// otherwise, marks them as having reached their pessimistic fixpoints.
+ /// `callable` must have a non-empty callable region. Can be overridden to,
+ /// e.g., be less conservative and propagate even if call sites are unknown.
+ virtual void
+ visitCallableOperation(CallableOpInterface callable,
+ ArrayRef<AbstractSparseLattice *> argLattices);
+
private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);
@@ -430,6 +454,16 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// Join the lattice element and propagate and update if it changed.
void meet(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs);
+ /// Visits the callable operation `callable` on behalf of `op`, a return-like
+ /// operation it directly contains. If all call sites of `callable` are known,
+ /// meets the result lattices of every call site into `operandLattices` (the
+ /// lattices of `op`'s operands); otherwise, conservatively sets them to their
+ /// exit states. Can be overridden to, for example, be less conservative and
+ /// propagate the information even if some call sites are unknown.
+ virtual LogicalResult
+ visitCallableOperation(Operation *op, CallableOpInterface callable,
+ ArrayRef<AbstractSparseLattice *> operandLattices);
+
private:
/// Recursively initialize the analysis on nested operations and blocks.
LogicalResult initializeRecursively(Operation *op);
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0b39d14042493..016e59dcb744e 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -128,34 +128,8 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
operandLattices.push_back(operandLattice);
}
- if (auto call = dyn_cast<CallOpInterface>(op)) {
- // If the call operation is to an external function, attempt to infer the
- // results from the call arguments.
- auto callable =
- dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
- if (!getSolverConfig().isInterprocedural() ||
- (callable && !callable.getCallableRegion())) {
- visitExternalCallImpl(call, operandLattices, resultLattices);
- return success();
- }
-
- // Otherwise, the results of a call operation are determined by the
- // callgraph.
- const auto *predecessors = getOrCreateFor<PredecessorState>(
- getProgramPointAfter(op), getProgramPointAfter(call));
- // If not all return sites are known, then conservatively assume we can't
- // reason about the data-flow.
- if (!predecessors->allPredecessorsKnown()) {
- setAllToEntryStates(resultLattices);
- return success();
- }
- for (Operation *predecessor : predecessors->getKnownPredecessors())
- for (auto &&[operand, resLattice] :
- llvm::zip(predecessor->getOperands(), resultLattices))
- join(resLattice,
- *getLatticeElementFor(getProgramPointAfter(op), operand));
- return success();
- }
+ if (auto call = dyn_cast<CallOpInterface>(op))
+ return visitCallOperation(call, operandLattices, resultLattices);
// Invoke the operation transfer function.
return visitOperationImpl(op, operandLattices, resultLattices);
@@ -183,24 +157,8 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
if (block->isEntryBlock()) {
// Check if this block is the entry block of a callable region.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
- if (callable && callable.getCallableRegion() == block->getParent()) {
- const auto *callsites = getOrCreateFor<PredecessorState>(
- getProgramPointBefore(block), getProgramPointAfter(callable));
- // If not all callsites are known, conservatively mark all lattices as
- // having reached their pessimistic fixpoints.
- if (!callsites->allPredecessorsKnown() ||
- !getSolverConfig().isInterprocedural()) {
- return setAllToEntryStates(argLattices);
- }
- for (Operation *callsite : callsites->getKnownPredecessors()) {
- auto call = cast<CallOpInterface>(callsite);
- for (auto it : llvm::zip(call.getArgOperands(), argLattices))
- join(std::get<1>(it),
- *getLatticeElementFor(getProgramPointBefore(block),
- std::get<0>(it)));
- }
- return;
- }
+ if (callable && callable.getCallableRegion() == block->getParent())
+ return visitCallableOperation(callable, argLattices);
// Check if the lattices can be determined from region control flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
@@ -248,6 +206,59 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
}
}
+LogicalResult AbstractSparseForwardDataFlowAnalysis::visitCallOperation(
+    CallOpInterface call,
+    ArrayRef<const AbstractSparseLattice *> operandLattices,
+    ArrayRef<AbstractSparseLattice *> resultLattices) {
+  // Calls to external functions (callables without a callable region), and
+  // every call when the solver is intraprocedural, go to the user hook.
+  auto target =
+      dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+  const bool targetIsExternal = target && !target.getCallableRegion();
+  if (!getSolverConfig().isInterprocedural() || targetIsExternal) {
+    visitExternalCallImpl(call, operandLattices, resultLattices);
+    return success();
+  }
+
+  // Otherwise the results come from the return sites in the callgraph. The
+  // point after the call is both the dependent and the queried anchor.
+  ProgramPoint *afterCall = getProgramPointAfter(call);
+  const auto *returnSites =
+      getOrCreateFor<PredecessorState>(afterCall, afterCall);
+  // With unknown return sites we cannot reason about the data flow.
+  if (!returnSites->allPredecessorsKnown()) {
+    setAllToEntryStates(resultLattices);
+    return success();
+  }
+  for (Operation *returnSite : returnSites->getKnownPredecessors()) {
+    for (auto &&[returnedValue, resultLattice] :
+         llvm::zip(returnSite->getOperands(), resultLattices))
+      join(resultLattice, *getLatticeElementFor(afterCall, returnedValue));
+  }
+  return success();
+}
+
+void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
+    CallableOpInterface callable,
+    ArrayRef<AbstractSparseLattice *> argLattices) {
+  // Dependencies are registered on the point before the entry block.
+  ProgramPoint *entryPoint =
+      getProgramPointBefore(&callable.getCallableRegion()->front());
+  const auto *callsites = getOrCreateFor<PredecessorState>(
+      entryPoint, getProgramPointAfter(callable));
+  // Unknown call sites, or no interprocedural analysis at all: arguments are
+  // pessimistic. Otherwise, join the actual arguments of every call site.
+  if (!callsites->allPredecessorsKnown() ||
+      !getSolverConfig().isInterprocedural())
+    return setAllToEntryStates(argLattices);
+  for (Operation *callsite : callsites->getKnownPredecessors()) {
+    auto call = cast<CallOpInterface>(callsite);
+    for (auto &&[argument, argLattice] :
+         llvm::zip(call.getArgOperands(), argLattices))
+      join(argLattice, *getLatticeElementFor(entryPoint, argument));
+  }
+}
+
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
@@ -512,31 +523,34 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
if (op->hasTrait<OpTrait::ReturnLike>()) {
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
- const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
- getProgramPointAfter(op), getProgramPointAfter(callable));
- if (callsites->allPredecessorsKnown()) {
- for (Operation *call : callsites->getKnownPredecessors()) {
- SmallVector<const AbstractSparseLattice *> callResultLattices =
- getLatticeElementsFor(getProgramPointAfter(op),
- call->getResults());
- for (auto [op, result] :
- llvm::zip(operandLattices, callResultLattices))
- meet(op, *result);
- }
- } else {
- // If we don't know all the callers, we can't know where the
- // returned values go. Note that, in particular, this will trigger
- // for the return ops of any public functions.
- setAllToExitStates(operandLattices);
- }
- return success();
- }
+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+ return visitCallableOperation(op, callable, operandLattices);
}
return visitOperationImpl(op, operandLattices, resultLattices);
}
+LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitCallableOperation(
+    Operation *op, CallableOpInterface callable,
+    ArrayRef<AbstractSparseLattice *> operandLattices) {
+  const PredecessorState *callsites = getOrCreateFor<PredecessorState>(
+      getProgramPointAfter(op), getProgramPointAfter(callable));
+  // If we don't know all the callers, we can't know where the returned values
+  // go (in particular, this triggers for the return ops of public functions).
+  if (!callsites->allPredecessorsKnown()) {
+    setAllToExitStates(operandLattices);
+    return success();
+  }
+  for (Operation *call : callsites->getKnownPredecessors()) {
+    SmallVector<const AbstractSparseLattice *> callResultLattices =
+        getLatticeElementsFor(getProgramPointAfter(op), call->getResults());
+    for (auto [operandLattice, resultLattice] :
+         llvm::zip(operandLattices, callResultLattices))
+      meet(operandLattice, *resultLattice);
+  }
+  return success();
+}
+
void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
RegionBranchOpInterface branch,
ArrayRef<AbstractSparseLattice *> operandLattices) {
More information about the Mlir-commits
mailing list