[Mlir-commits] [mlir] [mlir] support non-interprocedural dataflow analyses (PR #75583)
Oleksandr Alex Zinenko
llvmlistbot at llvm.org
Fri Dec 15 02:34:02 PST 2023
https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/75583
The core implementation of the dataflow analysis framework is interprocedural by design. While this offers better analysis precision, it also comes at an additional cost: the analysis takes longer to reach its fixpoint state. Add a configuration mechanism to the dataflow solver that controls whether it operates interprocedurally, so clients can choose.
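For example, a client that only needs the cheaper, function-local behavior could construct the solver roughly as follows. This is a usage sketch mirroring the updated test passes; the helper name and the particular analyses loaded are illustrative, not part of this patch:

  #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
  #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
  #include "mlir/Analysis/DataFlowFramework.h"
  #include "mlir/Support/LogicalResult.h"

  using namespace mlir;

  // Run a non-interprocedural (function-local) analysis over `op`.
  LogicalResult runLocalAnalysis(Operation *op) {
    // Opt out of interprocedural propagation; the default remains `true`.
    DataFlowSolver solver(DataFlowConfig().setInterprocedural(false));
    solver.load<dataflow::DeadCodeAnalysis>();
    solver.load<dataflow::SparseConstantPropagation>();
    if (failed(solver.initializeAndRun(op)))
      return failure();
    // Query solver.lookupState<...>(...) as usual afterwards.
    return success();
  }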
As a positive side effect, this change also adds hooks for explicitly processing external/opaque function calls in the dataflow analyses, e.g., based on attributes present on the function declaration or the call operation, such as the alias scopes and memory (modref) attributes available in the LLVM dialect.
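For illustration, a dense forward analysis could intercept the new `ExternalCallee` action roughly as sketched below. Only the hook and the `CallControlFlowAction::ExternalCallee` value come from this patch; `MyAnalysis`, `MyLattice` (assumed to be some `AbstractDenseLattice` subclass defined elsewhere), and the `"readonly"` attribute are made up for the example:

  #include "mlir/Analysis/DataFlow/DenseAnalysis.h"

  using namespace mlir;
  using namespace mlir::dataflow;

  class MyAnalysis : public DenseForwardDataFlowAnalysis<MyLattice> {
  public:
    using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;

    void visitCallControlFlowTransfer(CallOpInterface call,
                                      CallControlFlowAction action,
                                      const MyLattice &before,
                                      MyLattice *after) override {
      // Calls whose callee body is unavailable (or ignored because the solver
      // is configured as non-interprocedural) now reach this hook with the
      // ExternalCallee action.
      if (action == CallControlFlowAction::ExternalCallee &&
          call->hasAttr("readonly")) { // hypothetical attribute
        // Treat a read-only external call as a no-op on the tracked state.
        join(after, before);
        return;
      }
      // Otherwise keep the default behavior, which pessimizes external callees.
      AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer(
          call, action, before, after);
    }

    // visitOperation and setToEntryState overrides omitted for brevity.
  };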
This change should not affect existing analyses, and the default solver configuration remains interprocedural.
>From 7b3acd812701e85f614f38a84b4269ad0b144f95 Mon Sep 17 00:00:00 2001
From: Alex Zinenko <zinenko at google.com>
Date: Thu, 14 Dec 2023 14:40:28 +0000
Subject: [PATCH] [mlir] support non-interprocedural dataflow analyses
The core implementation of the dataflow analysis framework is
interprocedural by design. While this offers better analysis precision,
it also comes at an additional cost: the analysis takes longer to reach
its fixpoint state. Add a configuration mechanism to the dataflow
solver that controls whether it operates interprocedurally, so clients
can choose.
As a positive side effect, this change also adds hooks for explicitly
processing external/opaque function calls in the dataflow analyses,
e.g., based on attributes present on the function declaration or the
call operation, such as the alias scopes and memory (modref) attributes
available in the LLVM dialect.
This change should not affect existing analyses, and the default solver
configuration remains interprocedural.
Co-authored-by: Jacob Peng <jacobmpeng at gmail.com>
---
.../mlir/Analysis/DataFlow/DenseAnalysis.h | 40 +++--
.../mlir/Analysis/DataFlow/SparseAnalysis.h | 55 +++++++
.../include/mlir/Analysis/DataFlowFramework.h | 38 +++++
mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 42 ++++--
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 57 ++++---
.../test-last-modified-callgraph.mlir | 125 +++++++++++++---
.../Analysis/DataFlow/test-next-access.mlir | 140 ++++++++++++++----
.../Analysis/DataFlow/test-written-to.mlir | 90 +++++++++--
.../TestDenseBackwardDataFlowAnalysis.cpp | 103 ++++++++++---
.../DataFlow/TestDenseDataFlowAnalysis.h | 96 +++++++++---
.../TestDenseForwardDataFlowAnalysis.cpp | 107 ++++++++++---
.../TestSparseBackwardDataFlowAnalysis.cpp | 49 +++++-
12 files changed, 771 insertions(+), 171 deletions(-)
diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
index 6a1335bab8bf6e..088b6cd7d698fc 100644
--- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h
@@ -27,8 +27,9 @@ namespace dataflow {
// CallControlFlowAction
//===----------------------------------------------------------------------===//
-/// Indicates whether the control enters or exits the callee.
-enum class CallControlFlowAction { EnterCallee, ExitCallee };
+/// Indicates whether the control enters, exits, or skips over the callee (in
+/// the case of external functions).
+enum class CallControlFlowAction { EnterCallee, ExitCallee, ExternalCallee };
//===----------------------------------------------------------------------===//
// AbstractDenseLattice
@@ -131,14 +132,21 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// Propagate the dense lattice forward along the call control flow edge,
/// which can be either entering or exiting the callee. Default implementation
- /// just meets the states, meaning that operations implementing
- /// `CallOpInterface` don't have any effect on the lattice that isn't already
- /// expressed by the interface itself.
+ /// for enter and exit callee actions just meets the states, meaning that
+ /// operations implementing `CallOpInterface` don't have any effect on the
+ /// lattice that isn't already expressed by the interface itself. Default
+ /// implementation for the external callee action additionally sets the
+ /// "after" lattice to the entry state.
virtual void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
const AbstractDenseLattice &before,
AbstractDenseLattice *after) {
join(after, before);
+ // Note that `setToEntryState` may be a "partial fixpoint" for some
+ // lattices, e.g., lattices that are lists of maps of other lattices will
+ // only set fixpoint for "known" lattices.
+ if (action == CallControlFlowAction::ExternalCallee)
+ setToEntryState(after);
}
/// Visit a program point within a region branch operation with predecessors
@@ -155,7 +163,9 @@ class AbstractDenseForwardDataFlowAnalysis : public DataFlowAnalysis {
/// Visit an operation for which the data flow is described by the
/// `CallOpInterface`.
- void visitCallOperation(CallOpInterface call, AbstractDenseLattice *after);
+ void visitCallOperation(CallOpInterface call,
+ const AbstractDenseLattice &before,
+ AbstractDenseLattice *after);
};
//===----------------------------------------------------------------------===//
@@ -361,14 +371,22 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// Propagate the dense lattice backwards along the call control flow edge,
/// which can be either entering or exiting the callee. Default implementation
- /// just meets the states, meaning that operations implementing
- /// `CallOpInterface` don't have any effect on hte lattice that isn't already
- /// expressed by the interface itself.
+ /// for enter and exit callee actions just meets the states, meaning that
+ /// operations implementing `CallOpInterface` don't have any effect on the
+ /// lattice that isn't already expressed by the interface itself. Default
+ /// implementation for the external callee action additionally sets the
+ /// result to the exit (fixpoint) state.
virtual void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
meet(before, after);
+
+ // Note that `setToExitState` may be a "partial fixpoint" for some lattices,
+ // e.g., lattices that are lists of maps of other lattices will only
+ // set fixpoint for "known" lattices.
+ if (action == CallControlFlowAction::ExternalCallee)
+ setToExitState(before);
}
private:
@@ -394,7 +412,9 @@ class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis {
/// otherwise,
/// - meet that state with the state before the call-like op, or use the
/// custom logic if overridden by concrete analyses.
- void visitCallOperation(CallOpInterface call, AbstractDenseLattice *before);
+ void visitCallOperation(CallOpInterface call,
+ const AbstractDenseLattice &after,
+ AbstractDenseLattice *before);
/// Symbol table for call-level control flow.
SymbolTableCollection &symbolTable;
diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
index 5a9a36159b56c5..b65ac8bb1dec27 100644
--- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h
@@ -17,6 +17,7 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -199,6 +200,12 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis {
ArrayRef<const AbstractSparseLattice *> operandLattices,
ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
+ /// The transfer function for calls to external functions.
+ virtual void visitExternalCallImpl(
+ CallOpInterface call,
+ ArrayRef<const AbstractSparseLattice *> argumentLattices,
+ ArrayRef<AbstractSparseLattice *> resultLattices) = 0;
+
/// Given an operation with region control-flow, the lattices of the operands,
/// and a region successor, compute the lattice values for block arguments
/// that are not accounted for by the branching control flow (ex. the bounds
@@ -271,6 +278,14 @@ class SparseForwardDataFlowAnalysis
virtual void visitOperation(Operation *op, ArrayRef<const StateT *> operands,
ArrayRef<StateT *> results) = 0;
+ /// Visit a call operation to an externally defined function given the
+ /// lattices of its arguments.
+ virtual void visitExternalCall(CallOpInterface call,
+ ArrayRef<const StateT *> argumentLattices,
+ ArrayRef<StateT *> resultLattices) {
+ setAllToEntryStates(resultLattices);
+ }
+
/// Given an operation with possible region control-flow, the lattices of the
/// operands, and a region successor, compute the lattice values for block
/// arguments that are not accounted for by the branching control flow (ex.
@@ -321,6 +336,17 @@ class SparseForwardDataFlowAnalysis
{reinterpret_cast<StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
+ void visitExternalCallImpl(
+ CallOpInterface call,
+ ArrayRef<const AbstractSparseLattice *> argumentLattices,
+ ArrayRef<AbstractSparseLattice *> resultLattices) override {
+ visitExternalCall(
+ call,
+ {reinterpret_cast<const StateT *const *>(argumentLattices.begin()),
+ argumentLattices.size()},
+ {reinterpret_cast<StateT *const *>(resultLattices.begin()),
+ resultLattices.size()});
+ }
void visitNonControlFlowArgumentsImpl(
Operation *op, const RegionSuccessor &successor,
ArrayRef<AbstractSparseLattice *> argLattices,
@@ -363,6 +389,11 @@ class AbstractSparseBackwardDataFlowAnalysis : public DataFlowAnalysis {
Operation *op, ArrayRef<AbstractSparseLattice *> operandLattices,
ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
+ /// The transfer function for calls to external functions.
+ virtual void visitExternalCallImpl(
+ CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
+ ArrayRef<const AbstractSparseLattice *> resultLattices) = 0;
+
// Visit operands on branch instructions that are not forwarded.
virtual void visitBranchOperand(OpOperand &operand) = 0;
@@ -444,6 +475,19 @@ class SparseBackwardDataFlowAnalysis
virtual void visitOperation(Operation *op, ArrayRef<StateT *> operands,
ArrayRef<const StateT *> results) = 0;
+ /// Visit a call to an external function. This function is expected to set
+ /// lattice values of the call operands. By default, calls `visitCallOperand`
+ /// for all operands.
+ virtual void visitExternalCall(CallOpInterface call,
+ ArrayRef<StateT *> argumentLattices,
+ ArrayRef<const StateT *> resultLattices) {
+ (void)argumentLattices;
+ (void)resultLattices;
+ for (OpOperand &operand : call->getOpOperands()) {
+ visitCallOperand(operand);
+ }
+ };
+
protected:
/// Get the lattice element for a value.
StateT *getLatticeElement(Value value) override {
@@ -474,6 +518,17 @@ class SparseBackwardDataFlowAnalysis
{reinterpret_cast<const StateT *const *>(resultLattices.begin()),
resultLattices.size()});
}
+
+ void visitExternalCallImpl(
+ CallOpInterface call, ArrayRef<AbstractSparseLattice *> operandLattices,
+ ArrayRef<const AbstractSparseLattice *> resultLattices) override {
+ visitExternalCall(
+ call,
+ {reinterpret_cast<StateT *const *>(operandLattices.begin()),
+ operandLattices.size()},
+ {reinterpret_cast<const StateT *const *>(resultLattices.begin()),
+ resultLattices.size()});
+ }
};
} // end namespace dataflow
diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h
index c27615b52a12b8..541cdb1e237c1b 100644
--- a/mlir/include/mlir/Analysis/DataFlowFramework.h
+++ b/mlir/include/mlir/Analysis/DataFlowFramework.h
@@ -175,6 +175,32 @@ struct ProgramPoint
/// Forward declaration of the data-flow analysis class.
class DataFlowAnalysis;
+//===----------------------------------------------------------------------===//
+// DataFlowConfig
+//===----------------------------------------------------------------------===//
+
+/// Configuration class for data flow solver and child analyses. Follows the
+/// fluent API pattern.
+class DataFlowConfig {
+public:
+ DataFlowConfig() = default;
+
+ /// Set whether the solver should operate interprocedurally, i.e., enter the
+ /// callee body when available. Interprocedural analyses may be more precise,
+ /// but also more expensive as more states need to be computed and the
+ /// fixpoint convergence takes longer.
+ DataFlowConfig &setInterprocedural(bool enable) {
+ interprocedural = enable;
+ return *this;
+ }
+
+ /// Return `true` if the solver operates interprocedurally, `false` otherwise.
+ bool isInterprocedural() const { return interprocedural; }
+
+private:
+ bool interprocedural = true;
+};
+
//===----------------------------------------------------------------------===//
// DataFlowSolver
//===----------------------------------------------------------------------===//
@@ -195,6 +221,9 @@ class DataFlowAnalysis;
/// TODO: Optimize the internal implementation of the solver.
class DataFlowSolver {
public:
+ explicit DataFlowSolver(const DataFlowConfig &config = DataFlowConfig())
+ : config(config) {}
+
/// Load an analysis into the solver. Return the analysis instance.
template <typename AnalysisT, typename... Args>
AnalysisT *load(Args &&...args);
@@ -236,7 +265,13 @@ class DataFlowSolver {
/// dependent work items to the back of the queue.
void propagateIfChanged(AnalysisState *state, ChangeResult changed);
+ /// Get the configuration of the solver.
+ const DataFlowConfig &getConfig() const { return config; }
+
private:
+ /// Configuration of the dataflow solver.
+ DataFlowConfig config;
+
/// The solver's work queue. Work items can be inserted to the front of the
/// queue to be processed greedily, speeding up computations that otherwise
/// quickly degenerate to quadratic due to propagation of state updates.
@@ -423,6 +458,9 @@ class DataFlowAnalysis {
return state;
}
+ /// Return the configuration of the solver used for this analysis.
+ const DataFlowConfig &getSolverConfig() const { return solver.getConfig(); }
+
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// When compiling with debugging, keep a name for the analyis.
StringRef debugName;
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index a6c9f7d7da225e..08d89d6db788c8 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -54,12 +54,22 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint point) {
}
void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
- CallOpInterface call, AbstractDenseLattice *after) {
+ CallOpInterface call, const AbstractDenseLattice &before,
+ AbstractDenseLattice *after) {
+ // Allow for customizing the behavior of calls to external symbols, including
+ // when the analysis is explicitly marked as non-interprocedural.
+ auto callable =
+ dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+ if (!getSolverConfig().isInterprocedural() ||
+ (callable && !callable.getCallableRegion())) {
+ return visitCallControlFlowTransfer(
+ call, CallControlFlowAction::ExternalCallee, before, after);
+ }
const auto *predecessors =
getOrCreateFor<PredecessorState>(call.getOperation(), call);
- // If not all return sites are known, then conservatively assume we can't
- // reason about the data-flow.
+ // Otherwise, if not all return sites are known, then conservatively assume we
+ // can't reason about the data-flow.
if (!predecessors->allPredecessorsKnown())
return setToEntryState(after);
@@ -108,7 +118,7 @@ void AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
// If this is a call operation, then join its lattices across known return
// sites.
if (auto call = dyn_cast<CallOpInterface>(op))
- return visitCallOperation(call, after);
+ return visitCallOperation(call, *before, after);
// Invoke the operation transfer function.
visitOperationImpl(op, *before, after);
@@ -130,8 +140,10 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
if (callable && callable.getCallableRegion() == block->getParent()) {
const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
// If not all callsites are known, conservatively mark all lattices as
- // having reached their pessimistic fixpoints.
- if (!callsites->allPredecessorsKnown())
+ // having reached their pessimistic fixpoints. Do the same if
+ // interprocedural analysis is not enabled.
+ if (!callsites->allPredecessorsKnown() ||
+ !getSolverConfig().isInterprocedural())
return setToEntryState(after);
for (Operation *callsite : callsites->getKnownPredecessors()) {
// Get the dense lattice before the callsite.
@@ -267,18 +279,20 @@ LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) {
}
void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
- CallOpInterface call, AbstractDenseLattice *before) {
+ CallOpInterface call, const AbstractDenseLattice &after,
+ AbstractDenseLattice *before) {
// Find the callee.
Operation *callee = call.resolveCallable(&symbolTable);
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
if (!callable)
return setToExitState(before);
- // No region means the callee is only declared in this module and we shouldn't
- // assume anything about it.
+ // No region means the callee is only declared in this module.
Region *region = callable.getCallableRegion();
- if (!region || region->empty())
- return setToExitState(before);
+ if (!region || region->empty() || !getSolverConfig().isInterprocedural()) {
+ return visitCallControlFlowTransfer(
+ call, CallControlFlowAction::ExternalCallee, after, before);
+ }
// Call-level control flow specifies the data flow here.
//
@@ -324,7 +338,7 @@ void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
return visitRegionBranchOperation(op, branch, RegionBranchPoint::parent(),
before);
if (auto call = dyn_cast<CallOpInterface>(op))
- return visitCallOperation(call, before);
+ return visitCallOperation(call, *after, before);
// Invoke the operation transfer function.
visitOperationImpl(op, *after, before);
@@ -359,8 +373,10 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
// If not all call sites are known, conservative mark all lattices as
// having reached their pessimistic fix points.
- if (!callsites->allPredecessorsKnown())
+ if (!callsites->allPredecessorsKnown() ||
+ !getSolverConfig().isInterprocedural()) {
return setToExitState(before);
+ }
for (Operation *callsite : callsites->getKnownPredecessors()) {
const AbstractDenseLattice *after;
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 9f544d656df925..b47bba16fd9024 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -116,8 +116,27 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
resultLattices);
}
- // The results of a call operation are determined by the callgraph.
+ // Grab the lattice elements of the operands.
+ SmallVector<const AbstractSparseLattice *> operandLattices;
+ operandLattices.reserve(op->getNumOperands());
+ for (Value operand : op->getOperands()) {
+ AbstractSparseLattice *operandLattice = getLatticeElement(operand);
+ operandLattice->useDefSubscribe(this);
+ operandLattices.push_back(operandLattice);
+ }
+
if (auto call = dyn_cast<CallOpInterface>(op)) {
+ // If the call operation is to an external function, attempt to infer the
+ // results from the call arguments.
+ auto callable =
+ dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
+ if (!getSolverConfig().isInterprocedural() ||
+ (callable && !callable.getCallableRegion())) {
+ return visitExternalCallImpl(call, operandLattices, resultLattices);
+ }
+
+ // Otherwise, the results of a call operation are determined by the
+ // callgraph.
const auto *predecessors = getOrCreateFor<PredecessorState>(op, call);
// If not all return sites are known, then conservatively assume we can't
// reason about the data-flow.
@@ -129,15 +148,6 @@ void AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
return;
}
- // Grab the lattice elements of the operands.
- SmallVector<const AbstractSparseLattice *> operandLattices;
- operandLattices.reserve(op->getNumOperands());
- for (Value operand : op->getOperands()) {
- AbstractSparseLattice *operandLattice = getLatticeElement(operand);
- operandLattice->useDefSubscribe(this);
- operandLattices.push_back(operandLattice);
- }
-
// Invoke the operation transfer function.
visitOperationImpl(op, operandLattices, resultLattices);
}
@@ -168,8 +178,10 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
const auto *callsites = getOrCreateFor<PredecessorState>(block, callable);
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints.
- if (!callsites->allPredecessorsKnown())
+ if (!callsites->allPredecessorsKnown() ||
+ !getSolverConfig().isInterprocedural()) {
return setAllToEntryStates(argLattices);
+ }
for (Operation *callsite : callsites->getKnownPredecessors()) {
auto call = cast<CallOpInterface>(callsite);
for (auto it : llvm::zip(call.getArgOperands(), argLattices))
@@ -433,19 +445,26 @@ void AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// stored in `unaccounted`.
BitVector unaccounted(op->getNumOperands(), true);
+ // If the call invokes an external function (or a function treated as
+ // external due to config), defer to the corresponding extension hook.
+ // By default, it just does `visitCallOperand` for all operands.
OperandRange argOperands = call.getArgOperands();
MutableArrayRef<OpOperand> argOpOperands =
operandsToOpOperands(argOperands);
Region *region = callable.getCallableRegion();
- if (region && !region->empty()) {
- Block &block = region->front();
- for (auto [blockArg, argOpOperand] :
- llvm::zip(block.getArguments(), argOpOperands)) {
- meet(getLatticeElement(argOpOperand.get()),
- *getLatticeElementFor(op, blockArg));
- unaccounted.reset(argOpOperand.getOperandNumber());
- }
+ if (!region || region->empty() || !getSolverConfig().isInterprocedural())
+ return visitExternalCallImpl(call, operandLattices, resultLattices);
+
+ // Otherwise, propagate information from the entry point of the function
+ // back to operands whenever possible.
+ Block &block = region->front();
+ for (auto [blockArg, argOpOperand] :
+ llvm::zip(block.getArguments(), argOpOperands)) {
+ meet(getLatticeElement(argOpOperand.get()),
+ *getLatticeElementFor(op, blockArg));
+ unaccounted.reset(argOpOperand.getOperandNumber());
}
+
// Handle the operands of the call op that aren't forwarded to any
// arguments.
for (int index : unaccounted.set_bits()) {
diff --git a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
index 709d787bb306be..a5eba43ac68ab1 100644
--- a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
+++ b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir
@@ -1,8 +1,32 @@
-// RUN: mlir-opt -test-last-modified --split-input-file %s 2>&1 | FileCheck %s
+// RUN: mlir-opt -test-last-modified --split-input-file %s 2>&1 |\
+// RUN: FileCheck %s --check-prefixes=CHECK,IP,IP_ONLY
+// RUN: mlir-opt -test-last-modified='assume-func-writes=true' \
+// RUN: --split-input-file %s 2>&1 |\
+// RUN: FileCheck %s --check-prefixes=CHECK,IP,IP_AW
+// RUN: mlir-opt -test-last-modified='interprocedural=false' \
+// RUN: --split-input-file %s 2>&1 |\
+// RUN: FileCheck %s --check-prefixes=CHECK,LOCAL
+// RUN: mlir-opt \
+// RUN: -test-last-modified='interprocedural=false assume-func-writes=true' \
+// RUN: --split-input-file %s 2>&1 |\
+// RUN: FileCheck %s --check-prefixes=CHECK,LC_AW
+
+// Check prefixes are as follows:
+// 'check': common for all runs;
+// 'ip': interprocedural runs;
+// 'ip_aw': interprocedural runs assuming calls to external functions write to
+// all arguments;
+// 'ip_only': interprocedural runs not assuming that calls write;
+// 'local': local (non-interprocedural) analysis not assuming that calls write;
+// 'lc_aw': local analysis assuming external calls write to all arguments.
// CHECK-LABEL: test_tag: test_callsite
-// CHECK: operand #0
-// CHECK-NEXT: - a
+// IP: operand #0
+// IP-NEXT: - a
+// LOCAL: operand #0
+// LOCAL-NEXT: - <unknown>
+// LC_AW: operand #0
+// LC_AW-NEXT: - <unknown>
func.func private @single_callsite_fn(%ptr: memref<i32>) -> memref<i32> {
return {tag = "test_callsite"} %ptr : memref<i32>
}
@@ -16,8 +40,12 @@ func.func @test_callsite() {
}
// CHECK-LABEL: test_tag: test_return_site
-// CHECK: operand #0
-// CHECK-NEXT: - b
+// IP: operand #0
+// IP-NEXT: - b
+// LOCAL: operand #0
+// LOCAL-NEXT: - <unknown>
+// LC_AW: operand #0
+// LC_AW-NEXT: - <unknown>
func.func private @single_return_site_fn(%ptr: memref<i32>) -> memref<i32> {
%c0 = arith.constant 0 : i32
memref.store %c0, %ptr[] {tag_name = "b"} : memref<i32>
@@ -25,9 +53,13 @@ func.func private @single_return_site_fn(%ptr: memref<i32>) -> memref<i32> {
}
// CHECK-LABEL: test_tag: test_multiple_callsites
-// CHECK: operand #0
-// CHECK-NEXT: write0
-// CHECK-NEXT: write1
+// IP: operand #0
+// IP-NEXT: write0
+// IP-NEXT: write1
+// LOCAL: operand #0
+// LOCAL-NEXT: - <unknown>
+// LC_AW: operand #0
+// LC_AW-NEXT: - <unknown>
func.func @test_return_site(%ptr: memref<i32>) -> memref<i32> {
%0 = func.call @single_return_site_fn(%ptr) : (memref<i32>) -> memref<i32>
return {tag = "test_return_site"} %0 : memref<i32>
@@ -46,9 +78,13 @@ func.func @test_multiple_callsites(%a: i32, %ptr: memref<i32>) -> memref<i32> {
}
// CHECK-LABEL: test_tag: test_multiple_return_sites
-// CHECK: operand #0
-// CHECK-NEXT: return0
-// CHECK-NEXT: return1
+// IP: operand #0
+// IP-NEXT: return0
+// IP-NEXT: return1
+// LOCAL: operand #0
+// LOCAL-NEXT: - <unknown>
+// LC_AW: operand #0
+// LC_AW-NEXT: - <unknown>
func.func private @multiple_return_site_fn(%cond: i1, %a: i32, %ptr: memref<i32>) -> memref<i32> {
cf.cond_br %cond, ^a, ^b
@@ -69,8 +105,12 @@ func.func @test_multiple_return_sites(%cond: i1, %a: i32, %ptr: memref<i32>) ->
// -----
// CHECK-LABEL: test_tag: after_call
-// CHECK: operand #0
-// CHECK-NEXT: - write0
+// IP: operand #0
+// IP-NEXT: - write0
+// LOCAL: operand #0
+// LOCAL-NEXT: - <unknown>
+// LC_AW: operand #0
+// LC_AW-NEXT: - func.call
func.func private @void_return(%ptr: memref<i32>) {
return
}
@@ -98,17 +138,29 @@ func.func private @callee(%arg0: memref<f32>) -> memref<f32> {
// "pre" -> "call" -> "callee" -> "post"
// CHECK-LABEL: test_tag: call_and_store_before::enter_callee:
-// CHECK: operand #0
-// CHECK: - call
+// IP: operand #0
+// IP: - call
+// LOCAL: operand #0
+// LOCAL: - <unknown>
+// LC_AW: operand #0
+// LC_AW: - <unknown>
+
// CHECK: test_tag: exit_callee:
// CHECK: operand #0
// CHECK: - callee
+
// CHECK: test_tag: before_call:
// CHECK: operand #0
// CHECK: - pre
+
// CHECK: test_tag: after_call:
-// CHECK: operand #0
-// CHECK: - callee
+// IP: operand #0
+// IP: - callee
+// LOCAL: operand #0
+// LOCAL: - <unknown>
+// LC_AW: operand #0
+// LC_AW: - call
+
// CHECK: test_tag: return:
// CHECK: operand #0
// CHECK: - post
@@ -138,17 +190,29 @@ func.func private @callee(%arg0: memref<f32>) -> memref<f32> {
// "pre" -> "callee" -> "call" -> "post"
// CHECK-LABEL: test_tag: call_and_store_after::enter_callee:
-// CHECK: operand #0
-// CHECK: - pre
+// IP: operand #0
+// IP: - pre
+// LOCAL: operand #0
+// LOCAL: - <unknown>
+// LC_AW: operand #0
+// LC_AW: - <unknown>
+
// CHECK: test_tag: exit_callee:
// CHECK: operand #0
// CHECK: - callee
+
// CHECK: test_tag: before_call:
// CHECK: operand #0
// CHECK: - pre
-// CHECK: test_tag: after_call:
-// CHECK: operand #0
-// CHECK: - call
+
+// CHECK: test_tag: after_call:
+// IP: operand #0
+// IP: - call
+// LOCAL: operand #0
+// LOCAL: - <unknown>
+// LC_AW: operand #0
+// LC_AW: - call
+
// CHECK: test_tag: return:
// CHECK: operand #0
// CHECK: - post
@@ -162,3 +226,20 @@ func.func @call_and_store_after(%arg0: memref<f32>) -> memref<f32> {
memref.store %1, %arg0[] {tag_name = "post"} : memref<f32>
return {tag = "return"} %arg0 : memref<f32>
}
+
+// -----
+
+func.func private @void_return(%ptr: memref<i32>)
+
+// CHECK-LABEL: test_tag: after_opaque_call:
+// CHECK: operand #0
+// IP_ONLY: - <unknown>
+// IP_AW: - func.call
+func.func @test_opaque_call_return() {
+ %ptr = memref.alloc() : memref<i32>
+ %c0 = arith.constant 0 : i32
+ memref.store %c0, %ptr[] {tag_name = "write0"} : memref<i32>
+ func.call @void_return(%ptr) : (memref<i32>) -> ()
+ memref.load %ptr[] {tag = "after_opaque_call"} : memref<i32>
+ return
+}
diff --git a/mlir/test/Analysis/DataFlow/test-next-access.mlir b/mlir/test/Analysis/DataFlow/test-next-access.mlir
index 313a75c171d01d..de0788fb6a1768 100644
--- a/mlir/test/Analysis/DataFlow/test-next-access.mlir
+++ b/mlir/test/Analysis/DataFlow/test-next-access.mlir
@@ -1,4 +1,22 @@
-// RUN: mlir-opt %s --test-next-access --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --test-next-access --split-input-file |\
+// RUN: FileCheck %s --check-prefixes=CHECK,IP
+// RUN: mlir-opt %s --test-next-access='interprocedural=false' \
+// RUN: --split-input-file |\
+// RUN: FileCheck %s --check-prefixes=CHECK,LOCAL
+// RUN: mlir-opt %s --test-next-access='assume-func-reads=true' \
+// RUN: --split-input-file |\
+// RUN: FileCheck %s --check-prefixes=CHECK,IP_AR
+// RUN: mlir-opt %s \
+// RUN: --test-next-access='interprocedural=false assume-func-reads=true' \
+// RUN: --split-input-file | FileCheck %s --check-prefixes=CHECK,LC_AR
+
+// Check prefixes are as follows:
+// 'check': common for all runs;
+// 'ip_ar': interprocedural runs assuming calls to external functions read
+// all arguments;
+// 'ip': interprocedural runs not assuming that function calls read;
+// 'local': local (non-interprocedural) analysis not assuming that calls read;
+// 'lc_ar': local analysis assuming external calls read all arguments.
// CHECK-LABEL: @trivial
func.func @trivial(%arg0: memref<f32>, %arg1: f32) -> f32 {
@@ -252,8 +270,10 @@ func.func @known_conditional_cf(%arg0: memref<f32>) {
// -----
func.func private @callee1(%arg0: memref<f32>) {
- // CHECK: name = "callee1"
- // CHECK-SAME: next_access = {{\[}}["post"]]
+ // IP: name = "callee1"
+ // IP-SAME: next_access = {{\[}}["post"]]
+ // LOCAL: name = "callee1"
+ // LOCAL-SAME: next_access = ["unknown"]
memref.load %arg0[] {name = "callee1"} : memref<f32>
return
}
@@ -267,10 +287,14 @@ func.func private @callee2(%arg0: memref<f32>) {
// CHECK-LABEL: @simple_call
func.func @simple_call(%arg0: memref<f32>) {
- // CHECK: name = "caller"
- // CHECK-SAME: next_access = {{\[}}["callee1"]]
+ // IP: name = "caller"
+ // IP-SAME: next_access = {{\[}}["callee1"]]
+ // LOCAL: name = "caller"
+ // LOCAL-SAME: next_access = ["unknown"]
+ // LC_AR: name = "caller"
+ // LC_AR-SAME: next_access = {{\[}}["call"]]
memref.load %arg0[] {name = "caller"} : memref<f32>
- func.call @callee1(%arg0) : (memref<f32>) -> ()
+ func.call @callee1(%arg0) {name = "call"} : (memref<f32>) -> ()
memref.load %arg0[] {name = "post"} : memref<f32>
return
}
@@ -279,10 +303,14 @@ func.func @simple_call(%arg0: memref<f32>) {
// CHECK-LABEL: @infinite_recursive_call
func.func @infinite_recursive_call(%arg0: memref<f32>) {
- // CHECK: name = "pre"
- // CHECK-SAME: next_access = {{\[}}["pre"]]
+ // IP: name = "pre"
+ // IP-SAME: next_access = {{\[}}["pre"]]
+ // LOCAL: name = "pre"
+ // LOCAL-SAME: next_access = ["unknown"]
+ // LC_AR: name = "pre"
+ // LC_AR-SAME: next_access = {{\[}}["call"]]
memref.load %arg0[] {name = "pre"} : memref<f32>
- func.call @infinite_recursive_call(%arg0) : (memref<f32>) -> ()
+ func.call @infinite_recursive_call(%arg0) {name = "call"} : (memref<f32>) -> ()
memref.load %arg0[] {name = "post"} : memref<f32>
return
}
@@ -291,11 +319,15 @@ func.func @infinite_recursive_call(%arg0: memref<f32>) {
// CHECK-LABEL: @recursive_call
func.func @recursive_call(%arg0: memref<f32>, %cond: i1) {
- // CHECK: name = "pre"
- // CHECK-SAME: next_access = {{\[}}["post", "pre"]]
+ // IP: name = "pre"
+ // IP-SAME: next_access = {{\[}}["post", "pre"]]
+ // LOCAL: name = "pre"
+ // LOCAL-SAME: next_access = ["unknown"]
+ // LC_AR: name = "pre"
+ // LC_AR-SAME: next_access = {{\[}}["post", "call"]]
memref.load %arg0[] {name = "pre"} : memref<f32>
scf.if %cond {
- func.call @recursive_call(%arg0, %cond) : (memref<f32>, i1) -> ()
+ func.call @recursive_call(%arg0, %cond) {name = "call"} : (memref<f32>, i1) -> ()
}
memref.load %arg0[] {name = "post"} : memref<f32>
return
@@ -305,12 +337,16 @@ func.func @recursive_call(%arg0: memref<f32>, %cond: i1) {
// CHECK-LABEL: @recursive_call_cf
func.func @recursive_call_cf(%arg0: memref<f32>, %cond: i1) {
- // CHECK: name = "pre"
- // CHECK-SAME: next_access = {{\[}}["pre", "post"]]
+ // IP: name = "pre"
+ // IP-SAME: next_access = {{\[}}["pre", "post"]]
+ // LOCAL: name = "pre"
+ // LOCAL-SAME: next_access = ["unknown"]
+ // LC_AR: name = "pre"
+ // LC_AR-SAME: next_access = {{\[}}["call", "post"]]
%0 = memref.load %arg0[] {name = "pre"} : memref<f32>
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
- call @recursive_call_cf(%arg0, %cond) : (memref<f32>, i1) -> ()
+ call @recursive_call_cf(%arg0, %cond) {name = "call"} : (memref<f32>, i1) -> ()
cf.br ^bb2
^bb2:
%2 = memref.load %arg0[] {name = "post"} : memref<f32>
@@ -320,27 +356,35 @@ func.func @recursive_call_cf(%arg0: memref<f32>, %cond: i1) {
// -----
func.func private @callee1(%arg0: memref<f32>) {
- // CHECK: name = "callee1"
- // CHECK-SAME: next_access = {{\[}}["post"]]
+ // IP: name = "callee1"
+ // IP-SAME: next_access = {{\[}}["post"]]
+ // LOCAL: name = "callee1"
+ // LOCAL-SAME: next_access = ["unknown"]
memref.load %arg0[] {name = "callee1"} : memref<f32>
return
}
func.func private @callee2(%arg0: memref<f32>) {
- // CHECK: name = "callee2"
- // CHECK-SAME: next_access = {{\[}}["post"]]
+ // IP: name = "callee2"
+ // IP-SAME: next_access = {{\[}}["post"]]
+ // LOCAL: name = "callee2"
+ // LOCAL-SAME: next_access = ["unknown"]
memref.load %arg0[] {name = "callee2"} : memref<f32>
return
}
func.func @conditonal_call(%arg0: memref<f32>, %cond: i1) {
- // CHECK: name = "pre"
- // CHECK-SAME: next_access = {{\[}}["callee1", "callee2"]]
+ // IP: name = "pre"
+ // IP-SAME: next_access = {{\[}}["callee1", "callee2"]]
+ // LOCAL: name = "pre"
+ // LOCAL-SAME: next_access = ["unknown"]
+ // LC_AR: name = "pre"
+ // LC_AR-SAME: next_access = {{\[}}["call1", "call2"]]
memref.load %arg0[] {name = "pre"} : memref<f32>
scf.if %cond {
- func.call @callee1(%arg0) : (memref<f32>) -> ()
+ func.call @callee1(%arg0) {name = "call1"} : (memref<f32>) -> ()
} else {
- func.call @callee2(%arg0) : (memref<f32>) -> ()
+ func.call @callee2(%arg0) {name = "call2"} : (memref<f32>) -> ()
}
memref.load %arg0[] {name = "post"} : memref<f32>
return
@@ -354,16 +398,22 @@ func.func @conditonal_call(%arg0: memref<f32>, %cond: i1) {
// "caller" -> "call" -> "callee" -> "post"
func.func private @callee(%arg0: memref<f32>) {
- // CHECK: name = "callee"
- // CHECK-SAME-LITERAL: next_access = [["post"]]
+ // IP: name = "callee"
+ // IP-SAME-LITERAL: next_access = [["post"]]
+ // LOCAL: name = "callee"
+ // LOCAL-SAME: next_access = ["unknown"]
memref.load %arg0[] {name = "callee"} : memref<f32>
return
}
// CHECK-LABEL: @call_and_store_before
func.func @call_and_store_before(%arg0: memref<f32>) {
- // CHECK: name = "caller"
- // CHECK-SAME-LITERAL: next_access = [["call"]]
+ // IP: name = "caller"
+ // IP-SAME-LITERAL: next_access = [["call"]]
+ // LOCAL: name = "caller"
+ // LOCAL-SAME: next_access = ["unknown"]
+ // LC_AR: name = "caller"
+ // LC_AR-SAME: next_access = {{\[}}["call"]]
memref.load %arg0[] {name = "caller"} : memref<f32>
// Note that the access after the entire call is "post".
// CHECK: name = "call"
@@ -382,20 +432,26 @@ func.func @call_and_store_before(%arg0: memref<f32>) {
// "caller" -> "callee" -> "call" -> "post"
func.func private @callee(%arg0: memref<f32>) {
- // CHECK: name = "callee"
- // CHECK-SAME-LITERAL: next_access = [["call"]]
+ // IP: name = "callee"
+ // IP-SAME-LITERAL: next_access = [["call"]]
+ // LOCAL: name = "callee"
+ // LOCAL-SAME: next_access = ["unknown"]
memref.load %arg0[] {name = "callee"} : memref<f32>
return
}
// CHECK-LABEL: @call_and_store_after
func.func @call_and_store_after(%arg0: memref<f32>) {
- // CHECK: name = "caller"
- // CHECK-SAME-LITERAL: next_access = [["callee"]]
+ // IP: name = "caller"
+ // IP-SAME-LITERAL: next_access = [["callee"]]
+ // LOCAL: name = "caller"
+ // LOCAL-SAME: next_access = ["unknown"]
+ // LC_AR: name = "caller"
+ // LC_AR-SAME: next_access = {{\[}}["call"]]
memref.load %arg0[] {name = "caller"} : memref<f32>
// CHECK: name = "call"
// CHECK-SAME-LITERAL: next_access = [["post"], ["post"]]
- test.call_and_store @callee(%arg0), %arg0 {name = "call", store_before_call = true} : (memref<f32>, memref<f32>) -> ()
+ test.call_and_store @callee(%arg0), %arg0 {name = "call", store_before_call = false} : (memref<f32>, memref<f32>) -> ()
// CHECK: name = "post"
// CHECK-SAME-LITERAL: next_access = ["unknown"]
memref.load %arg0[] {name = "post"} : memref<f32>
@@ -499,3 +555,23 @@ func.func @store_with_a_region_after_containing_a_load(%arg0: memref<f32>) {
memref.load %arg0[] {name = "post"} : memref<f32>
return
}
+
+// -----
+
+func.func private @opaque_callee(%arg0: memref<f32>)
+
+// CHECK-LABEL: @call_opaque_callee
+func.func @call_opaque_callee(%arg0: memref<f32>) {
+ // IP: name = "pre"
+ // IP-SAME: next_access = ["unknown"]
+ // IP_AR: name = "pre"
+ // IP_AR-SAME: next_access = {{\[}}["call"]]
+ // LOCAL: name = "pre"
+ // LOCAL-SAME: next_access = ["unknown"]
+ // LC_AR: name = "pre"
+ // LC_AR-SAME: next_access = {{\[}}["call"]]
+ memref.load %arg0[] {name = "pre"} : memref<f32>
+ func.call @opaque_callee(%arg0) {name = "call"} : (memref<f32>) -> ()
+ memref.load %arg0[] {name = "post"} : memref<f32>
+ return
+}
diff --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir
index 82fe755aaf5d46..4fc9af164d48e8 100644
--- a/mlir/test/Analysis/DataFlow/test-written-to.mlir
+++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir
@@ -1,4 +1,28 @@
-// RUN: mlir-opt -split-input-file -test-written-to %s 2>&1 | FileCheck %s
+// RUN: mlir-opt -split-input-file -test-written-to %s 2>&1 |\
+// RUN: FileCheck %s --check-prefixes=CHECK,IP
+// RUN: mlir-opt -split-input-file -test-written-to='interprocedural=false' %s \
+// RUN: 2>&1 | FileCheck %s --check-prefixes=CHECK,LOCAL
+// RUN: mlir-opt -split-input-file \
+// RUN: -test-written-to='assume-func-writes=true' %s 2>&1 |\
+// RUN: FileCheck %s --check-prefixes=CHECK,IP_AW
+// RUN: mlir-opt -split-input-file \
+// RUN: -test-written-to='interprocedural=false assume-func-writes=true' \
+// RUN: %s 2>&1 | FileCheck %s --check-prefixes=CHECK,LC_AW
+
+// Check prefixes are as follows:
+// 'check': common for all runs;
+// 'ip': interprocedural runs;
+// 'ip_aw': interprocedural runs assuming calls to external functions write to
+// all arguments;
+// 'local': local (non-interprocedural) analysis not assuming that calls write;
+// 'lc_aw': local analysis assuming external calls write to all arguments.
+
+// Note that despite the name of the test analysis being "written to", it is set
+// up in a peculiar way where passing a value through a block or region argument
+// (via visitCall/BranchOperand) is considered as "writing" that value to the
+// corresponding operand, which is itself a value and not necessarily "memory".
+// This is arguably okay for testing purposes, but may be surprising for readers
+// trying to interpret this test using their intuition.
// CHECK-LABEL: test_tag: constant0
// CHECK: result #0: [a]
@@ -105,7 +129,9 @@ func.func @test_switch(%flag: i32, %m0: memref<i32>) {
// -----
// CHECK-LABEL: test_tag: add
-// CHECK: result #0: [a]
+// IP: result #0: [a]
+// LOCAL: result #0: [callarg0]
+// LC_AW: result #0: [func.call]
func.func @test_caller(%m0: memref<f32>, %arg: f32) {
%0 = arith.addf %arg, %arg {tag = "add"} : f32
%1 = func.call @callee(%0) : (f32) -> f32
@@ -130,7 +156,9 @@ func.func private @callee(%0 : f32) -> f32 {
}
// CHECK-LABEL: test_tag: sub
-// CHECK: result #0: [a]
+// IP: result #0: [a]
+// LOCAL: result #0: [callarg0]
+// LC_AW: result #0: [func.call]
func.func @test_caller_below_callee(%m0: memref<f32>, %arg: f32) {
%0 = arith.subf %arg, %arg {tag = "sub"} : f32
%1 = func.call @callee(%0) : (f32) -> f32
@@ -155,7 +183,9 @@ func.func private @callee3(%0 : f32) -> f32 {
}
// CHECK-LABEL: test_tag: mul
-// CHECK: result #0: [a]
+// IP: result #0: [a]
+// LOCAL: result #0: [callarg0]
+// LC_AW: result #0: [func.call]
func.func @test_callchain(%m0: memref<f32>, %arg: f32) {
%0 = arith.mulf %arg, %arg {tag = "mul"} : f32
%1 = func.call @callee1(%0) : (f32) -> f32
@@ -239,19 +269,19 @@ func.func @test_for(%m0: memref<i32>) {
// -----
// CHECK-LABEL: test_tag: default_a
-// CHECK-LABEL: result #0: [a]
+// CHECK: result #0: [a]
// CHECK-LABEL: test_tag: default_b
-// CHECK-LABEL: result #0: [b]
+// CHECK: result #0: [b]
// CHECK-LABEL: test_tag: 1a
-// CHECK-LABEL: result #0: [a]
+// CHECK: result #0: [a]
// CHECK-LABEL: test_tag: 1b
-// CHECK-LABEL: result #0: [b]
+// CHECK: result #0: [b]
// CHECK-LABEL: test_tag: 2a
-// CHECK-LABEL: result #0: [a]
+// CHECK: result #0: [a]
// CHECK-LABEL: test_tag: 2b
-// CHECK-LABEL: result #0: [b]
+// CHECK: result #0: [b]
// CHECK-LABEL: test_tag: switch
-// CHECK-LABEL: operand #0: [brancharg0]
+// CHECK: operand #0: [brancharg0]
func.func @test_switch(%arg0 : index, %m0: memref<i32>) {
%0, %1 = scf.index_switch %arg0 {tag="switch"} -> i32, i32
case 1 {
@@ -276,6 +306,9 @@ func.func @test_switch(%arg0 : index, %m0: memref<i32>) {
// -----
+// The point of this test is to ensure the analysis doesn't crash in presence of
+// external functions.
+
// CHECK-LABEL: llvm.func @decl(i64)
// CHECK-LABEL: llvm.func @func(%arg0: i64) {
// CHECK-NEXT: llvm.call @decl(%arg0) : (i64) -> ()
@@ -295,12 +328,39 @@ func.func private @callee(%arg0 : i32, %arg1 : i32) -> i32 {
}
// CHECK-LABEL: test_tag: a
-// CHECK-LABEL: operand #0: [b]
-// CHECK-LABEL: operand #1: []
-// CHECK-LABEL: operand #2: [callarg2]
-// CHECK-LABEL: result #0: [b]
+
+// IP: operand #0: [b]
+// LOCAL: operand #0: [callarg0]
+// LC_AW: operand #0: [test.call_on_device]
+
+// IP: operand #1: []
+// LOCAL: operand #1: [callarg1]
+// LC_AW: operand #1: [test.call_on_device]
+
+// IP: operand #2: [callarg2]
+// LOCAL: operand #2: [callarg2]
+// LC_AW: operand #2: [test.call_on_device]
+
+// CHECK: result #0: [b]
func.func @test_call_on_device(%arg0: i32, %arg1: i32, %device: i32, %m0: memref<i32>) {
%0 = test.call_on_device @callee(%arg0, %arg1), %device {tag = "a"} : (i32, i32, i32) -> (i32)
memref.store %0, %m0[] {tag_name = "b"} : memref<i32>
return
}
+
+// -----
+
+func.func private @external_callee(%arg0: i32) -> i32
+
+// CHECK-LABEL: test_tag: add_external
+// IP: operand #0: [callarg0]
+// LOCAL: operand #0: [callarg0]
+// LC_AW: operand #0: [func.call]
+// IP_AW: operand #0: [func.call]
+
+func.func @test_external_callee(%arg0: i32, %m0: memref<i32>) {
+ %0 = arith.addi %arg0, %arg0 { tag = "add_external"}: i32
+ %1 = func.call @external_callee(%arg0) : (i32) -> i32
+ memref.store %1, %m0[] {tag_name = "a"} : memref<i32>
+ return
+}
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
index 8bfd01d828060a..ca052392f2f5f2 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp
@@ -49,7 +49,10 @@ class NextAccess : public AbstractDenseLattice, public AccessLatticeBase {
class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
public:
- using DenseBackwardDataFlowAnalysis::DenseBackwardDataFlowAnalysis;
+ NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
+ bool assumeFuncReads = false)
+ : DenseBackwardDataFlowAnalysis(solver, symbolTable),
+ assumeFuncReads(assumeFuncReads) {}
void visitOperation(Operation *op, const NextAccess &after,
NextAccess *before) override;
@@ -69,8 +72,10 @@ class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
// means "we don't know what the next access is" rather than "there is no next
// access". But it's unclear how to differentiate the two cases...
void setToExitState(NextAccess *lattice) override {
- propagateIfChanged(lattice, lattice->reset());
+ propagateIfChanged(lattice, lattice->setKnownToUnknown());
}
+
+ const bool assumeFuncReads;
};
} // namespace
@@ -84,7 +89,13 @@ void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after,
SmallVector<MemoryEffects::EffectInstance> effects;
memory.getEffects(effects);
- ChangeResult result = before->meet(after);
+
+ // First, check if all underlying values are already known. Otherwise, avoid
+ // propagating and stay in the "undefined" state to avoid incorrectly
+ // propagating values that may be overwritten later on as that could be
+ // problematic for convergence based on monotonicity of lattice updates.
+ SmallVector<Value> underlyingValues;
+ underlyingValues.reserve(effects.size());
for (const MemoryEffects::EffectInstance &effect : effects) {
Value value = effect.getValue();
@@ -95,10 +106,23 @@ void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after,
// If cannot find the most underlying value, we cannot assume anything about
// the next accesses.
- value = UnderlyingValueAnalysis::getMostUnderlyingValue(
- value, [&](Value value) {
- return getOrCreateFor<UnderlyingValueLattice>(op, value);
- });
+ std::optional<Value> underlyingValue =
+ UnderlyingValueAnalysis::getMostUnderlyingValue(
+ value, [&](Value value) {
+ return getOrCreateFor<UnderlyingValueLattice>(op, value);
+ });
+
+ // If the underlying value is not known yet, don't propagate.
+ if (!underlyingValue)
+ return;
+
+ underlyingValues.push_back(*underlyingValue);
+ }
+
+ // Update the state if all underlying values are known.
+ ChangeResult result = before->meet(after);
+ for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) {
+ // If the underlying value is known to be unknown, set to fixpoint.
if (!value)
return setToExitState(before);
@@ -110,6 +134,27 @@ void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after,
void NextAccessAnalysis::visitCallControlFlowTransfer(
CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
NextAccess *before) {
+ if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) {
+ SmallVector<Value> underlyingValues;
+ underlyingValues.reserve(call->getNumOperands());
+ for (Value operand : call.getArgOperands()) {
+ std::optional<Value> underlyingValue =
+ UnderlyingValueAnalysis::getMostUnderlyingValue(
+ operand, [&](Value value) {
+ return getOrCreateFor<UnderlyingValueLattice>(
+ call.getOperation(), value);
+ });
+ if (!underlyingValue)
+ return;
+ underlyingValues.push_back(*underlyingValue);
+ }
+
+ ChangeResult result = before->meet(after);
+ for (Value operand : underlyingValues) {
+ result |= before->set(operand, call);
+ }
+ return propagateIfChanged(before, result);
+ }
auto testCallAndStore =
dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
@@ -143,10 +188,24 @@ void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
namespace {
struct TestNextAccessPass
: public PassWrapper<TestNextAccessPass, OperationPass<>> {
+ TestNextAccessPass() = default;
+ TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) {
+ interprocedural = other.interprocedural;
+ assumeFuncReads = other.assumeFuncReads;
+ }
+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass)
StringRef getArgument() const override { return "test-next-access"; }
+ Option<bool> interprocedural{
+ *this, "interprocedural", llvm::cl::init(true),
+ llvm::cl::desc("perform interprocedural analysis")};
+ Option<bool> assumeFuncReads{
+ *this, "assume-func-reads", llvm::cl::init(false),
+ llvm::cl::desc(
+ "assume external functions have read effect on all arguments")};
+
static constexpr llvm::StringLiteral kTagAttrName = "name";
static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access";
static constexpr llvm::StringLiteral kAtEntryPointAttrName =
@@ -158,22 +217,29 @@ struct TestNextAccessPass
if (!nextAccess)
return StringAttr::get(op->getContext(), "not computed");
+ // Note that if the underlying value could not be computed or is unknown, we
+ // conservatively treat the result as unknown as well.
SmallVector<Attribute> attrs;
for (Value operand : op->getOperands()) {
- Value value = UnderlyingValueAnalysis::getMostUnderlyingValue(
- operand, [&](Value value) {
- return solver.lookupState<UnderlyingValueLattice>(value);
- });
- std::optional<ArrayRef<Operation *>> nextAcc =
- nextAccess->getAdjacentAccess(value);
- if (!nextAcc) {
+ std::optional<Value> underlyingValue =
+ UnderlyingValueAnalysis::getMostUnderlyingValue(
+ operand, [&](Value value) {
+ return solver.lookupState<UnderlyingValueLattice>(value);
+ });
+ if (!underlyingValue) {
+ attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
+ continue;
+ }
+ Value value = *underlyingValue;
+ const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value);
+ if (!nextAcc || !nextAcc->isKnown()) {
attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
continue;
}
SmallVector<Attribute> innerAttrs;
- innerAttrs.reserve(nextAcc->size());
- for (Operation *nextAccOp : *nextAcc) {
+ innerAttrs.reserve(nextAcc->get().size());
+ for (Operation *nextAccOp : nextAcc->get()) {
if (auto nextAccTag =
nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) {
innerAttrs.push_back(nextAccTag);
@@ -193,9 +259,10 @@ struct TestNextAccessPass
Operation *op = getOperation();
SymbolTableCollection symbolTable;
- DataFlowSolver solver;
+ auto config = DataFlowConfig().setInterprocedural(interprocedural);
+ DataFlowSolver solver(config);
solver.load<DeadCodeAnalysis>();
- solver.load<NextAccessAnalysis>(symbolTable);
+ solver.load<NextAccessAnalysis>(symbolTable, assumeFuncReads);
solver.load<SparseConstantPropagation>();
solver.load<UnderlyingValueAnalysis>();
if (failed(solver.initializeAndRun(op))) {
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
index eab54fbcfbf4ae..61ddc13f8a3d4a 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
@@ -57,6 +57,62 @@ class UnderlyingValue {
std::optional<Value> underlyingValue;
};
+class AdjacentAccess {
+public:
+ using DeterministicSetVector =
+ SetVector<Operation *, SmallVector<Operation *, 2>,
+ SmallPtrSet<Operation *, 2>>;
+
+ ArrayRef<Operation *> get() const { return accesses.getArrayRef(); }
+ bool isKnown() const { return !unknown; }
+
+ ChangeResult merge(const AdjacentAccess &other) {
+ if (unknown)
+ return ChangeResult::NoChange;
+ if (other.unknown) {
+ unknown = true;
+ accesses.clear();
+ return ChangeResult::Change;
+ }
+
+ size_t sizeBefore = accesses.size();
+ accesses.insert(other.accesses.begin(), other.accesses.end());
+ return accesses.size() == sizeBefore ? ChangeResult::NoChange
+ : ChangeResult::Change;
+ }
+
+ ChangeResult set(Operation *op) {
+ if (!unknown && accesses.size() == 1 && *accesses.begin() == op)
+ return ChangeResult::NoChange;
+
+ unknown = false;
+ accesses.clear();
+ accesses.insert(op);
+ return ChangeResult::Change;
+ }
+
+ ChangeResult setUnknown() {
+ if (unknown)
+ return ChangeResult::NoChange;
+
+ accesses.clear();
+ unknown = true;
+ return ChangeResult::Change;
+ }
+
+ bool operator==(const AdjacentAccess &other) const {
+ return unknown == other.unknown && accesses == other.accesses;
+ }
+
+ bool operator!=(const AdjacentAccess &other) const {
+ return !operator==(other);
+ }
+
+private:
+ bool unknown = false;
+ DeterministicSetVector accesses;
+};
+
/// This lattice represents, for a given memory resource, the potential last
/// operations that modified the resource.
class AccessLatticeBase {
@@ -73,40 +129,42 @@ class AccessLatticeBase {
ChangeResult merge(const AccessLatticeBase &rhs) {
ChangeResult result = ChangeResult::NoChange;
for (const auto &mod : rhs.adjAccesses) {
- auto &lhsMod = adjAccesses[mod.first];
- if (lhsMod != mod.second) {
- lhsMod.insert(mod.second.begin(), mod.second.end());
- result |= ChangeResult::Change;
- }
+ AdjacentAccess &lhsMod = adjAccesses[mod.first];
+ result |= lhsMod.merge(mod.second);
}
return result;
}
/// Set the last modification of a value.
ChangeResult set(Value value, Operation *op) {
- auto &lastMod = adjAccesses[value];
+ AdjacentAccess &lastMod = adjAccesses[value];
+ return lastMod.set(op);
+ }
+
+ ChangeResult setKnownToUnknown() {
ChangeResult result = ChangeResult::NoChange;
- if (lastMod.size() != 1 || *lastMod.begin() != op) {
- result = ChangeResult::Change;
- lastMod.clear();
- lastMod.insert(op);
- }
+ for (auto &[value, adjacent] : adjAccesses)
+ result |= adjacent.setUnknown();
return result;
}
/// Get the adjacent accesses to a value. Returns std::nullopt if they
/// are not known.
- std::optional<ArrayRef<Operation *>> getAdjacentAccess(Value value) const {
+ const AdjacentAccess *getAdjacentAccess(Value value) const {
auto it = adjAccesses.find(value);
if (it == adjAccesses.end())
- return {};
- return it->second.getArrayRef();
+ return nullptr;
+ return &it->getSecond();
}
void print(raw_ostream &os) const {
for (const auto &lastMod : adjAccesses) {
os << lastMod.first << ":\n";
- for (Operation *op : lastMod.second)
+ if (!lastMod.second.isKnown()) {
+ os << " <unknown>\n";
+ return;
+ }
+ for (Operation *op : lastMod.second.get())
os << " " << *op << "\n";
}
}
@@ -114,9 +172,7 @@ class AccessLatticeBase {
private:
/// The potential adjacent accesses to a memory resource. Use a set vector to
/// keep the results deterministic.
- DenseMap<Value, SetVector<Operation *, SmallVector<Operation *, 2>,
- SmallPtrSet<Operation *, 2>>>
- adjAccesses;
+ DenseMap<Value, AdjacentAccess> adjAccesses;
};
/// Define the lattice class explicitly to provide a type ID.
@@ -148,7 +204,7 @@ class UnderlyingValueAnalysis
}
/// Look for the most underlying value of a value.
- static Value
+ static std::optional<Value>
getMostUnderlyingValue(Value value,
function_ref<const UnderlyingValueLattice *(Value)>
getUnderlyingValueFn) {
@@ -156,7 +212,7 @@ class UnderlyingValueAnalysis
do {
underlying = getUnderlyingValueFn(value);
if (!underlying || underlying->getValue().isUninitialized())
- return {};
+ return std::nullopt;
Value underlyingValue = underlying->getValue().getUnderlyingValue();
if (underlyingValue == value)
break;
diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
index 2520ed3d83b9ef..29480f5ad63ee0 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
@@ -49,7 +49,9 @@ class LastModification : public AbstractDenseLattice, public AccessLatticeBase {
class LastModifiedAnalysis
: public DenseForwardDataFlowAnalysis<LastModification> {
public:
- using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;
+ explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites)
+ : DenseForwardDataFlowAnalysis(solver),
+ assumeFuncWrites(assumeFuncWrites) {}
/// Visit an operation. If the operation has no memory effects, then the state
/// is propagated with no change. If the operation allocates a resource, then
@@ -74,6 +76,9 @@ class LastModifiedAnalysis
void setToEntryState(LastModification *lattice) override {
propagateIfChanged(lattice, lattice->reset());
}
+
+private:
+ const bool assumeFuncWrites;
};
} // end anonymous namespace
@@ -89,7 +94,12 @@ void LastModifiedAnalysis::visitOperation(Operation *op,
SmallVector<MemoryEffects::EffectInstance> effects;
memory.getEffects(effects);
- ChangeResult result = after->join(before);
+  // First, check that all underlying values are already known. Otherwise,
+  // avoid propagating and stay in the "undefined" state: propagating values
+  // that may still be overwritten later could be problematic for convergence,
+  // which relies on monotonic lattice updates.
+ SmallVector<Value> underlyingValues;
+ underlyingValues.reserve(effects.size());
for (const auto &effect : effects) {
Value value = effect.getValue();
@@ -100,10 +110,23 @@ void LastModifiedAnalysis::visitOperation(Operation *op,
// If we cannot find the underlying value, we shouldn't just propagate the
// effects through, return the pessimistic state.
- value = UnderlyingValueAnalysis::getMostUnderlyingValue(
- value, [&](Value value) {
- return getOrCreateFor<UnderlyingValueLattice>(op, value);
- });
+ std::optional<Value> underlyingValue =
+ UnderlyingValueAnalysis::getMostUnderlyingValue(
+ value, [&](Value value) {
+ return getOrCreateFor<UnderlyingValueLattice>(op, value);
+ });
+
+ // If the underlying value is not yet known, don't propagate yet.
+ if (!underlyingValue)
+ return;
+
+ underlyingValues.push_back(*underlyingValue);
+ }
+
+ // Update the state when all underlying values are known.
+ ChangeResult result = after->join(before);
+ for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) {
+ // If the underlying value is known to be unknown, set to fixpoint state.
if (!value)
return setToEntryState(after);
@@ -119,6 +142,26 @@ void LastModifiedAnalysis::visitOperation(Operation *op,
void LastModifiedAnalysis::visitCallControlFlowTransfer(
CallOpInterface call, CallControlFlowAction action,
const LastModification &before, LastModification *after) {
+ if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) {
+ SmallVector<Value> underlyingValues;
+ underlyingValues.reserve(call->getNumOperands());
+ for (Value operand : call.getArgOperands()) {
+ std::optional<Value> underlyingValue =
+ UnderlyingValueAnalysis::getMostUnderlyingValue(
+ operand, [&](Value value) {
+ return getOrCreateFor<UnderlyingValueLattice>(
+ call.getOperation(), value);
+ });
+ if (!underlyingValue)
+ return;
+ underlyingValues.push_back(*underlyingValue);
+ }
+
+ ChangeResult result = after->join(before);
+ for (Value operand : underlyingValues)
+ result |= after->set(operand, call);
+ return propagateIfChanged(after, result);
+ }
auto testCallAndStore =
dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
@@ -155,21 +198,37 @@ struct TestLastModifiedPass
: public PassWrapper<TestLastModifiedPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
+ TestLastModifiedPass() = default;
+ TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) {
+ interprocedural = other.interprocedural;
+ assumeFuncWrites = other.assumeFuncWrites;
+ }
+
StringRef getArgument() const override { return "test-last-modified"; }
+ Option<bool> interprocedural{
+ *this, "interprocedural", llvm::cl::init(true),
+ llvm::cl::desc("perform interprocedural analysis")};
+ Option<bool> assumeFuncWrites{
+ *this, "assume-func-writes", llvm::cl::init(false),
+ llvm::cl::desc(
+ "assume external functions have write effect on all arguments")};
+
void runOnOperation() override {
Operation *op = getOperation();
- DataFlowSolver solver;
+ DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
- solver.load<LastModifiedAnalysis>();
+ solver.load<LastModifiedAnalysis>(assumeFuncWrites);
solver.load<UnderlyingValueAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
raw_ostream &os = llvm::errs();
+  // Note that if the underlying value could not be computed or is unknown,
+  // we conservatively treat the result as unknown as well.
op->walk([&](Operation *op) {
auto tag = op->getAttrOfType<StringAttr>("tag");
if (!tag)
@@ -180,19 +239,29 @@ struct TestLastModifiedPass
assert(lastMods && "expected a dense lattice");
for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
os << " operand #" << index << "\n";
- Value value = UnderlyingValueAnalysis::getMostUnderlyingValue(
- operand, [&](Value value) {
- return solver.lookupState<UnderlyingValueLattice>(value);
- });
+ std::optional<Value> underlyingValue =
+ UnderlyingValueAnalysis::getMostUnderlyingValue(
+ operand, [&](Value value) {
+ return solver.lookupState<UnderlyingValueLattice>(value);
+ });
+ if (!underlyingValue) {
+ os << " - <unknown>\n";
+ continue;
+ }
+ Value value = *underlyingValue;
assert(value && "expected an underlying value");
- if (std::optional<ArrayRef<Operation *>> lastMod =
+ if (const AdjacentAccess *lastMod =
lastMods->getAdjacentAccess(value)) {
- for (Operation *lastModifier : *lastMod) {
- if (auto tagName =
- lastModifier->getAttrOfType<StringAttr>("tag_name")) {
- os << " - " << tagName.getValue() << "\n";
- } else {
- os << " - " << lastModifier->getName() << "\n";
+ if (!lastMod->isKnown()) {
+ os << " - <unknown>\n";
+ } else {
+ for (Operation *lastModifier : lastMod->get()) {
+ if (auto tagName =
+ lastModifier->getAttrOfType<StringAttr>("tag_name")) {
+ os << " - " << tagName.getValue() << "\n";
+ } else {
+ os << " - " << lastModifier->getName() << "\n";
+ }
}
}
} else {
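
As an aside on the pass changes above: the test passes now expose the new flags as pass options and define an explicit copy constructor that carries the option values over when the pass is copied. For a downstream pass that wants the same knob, a minimal standalone sketch of this pattern might look like the following (the pass and option names here are placeholders, not part of the patch; only the DataFlowConfig/DataFlowSolver wiring mirrors the hunk above).

// Standalone sketch: a pass exposing an "interprocedural" option and threading
// it into the dataflow solver configuration introduced by this patch.
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/TypeID.h"

using namespace mlir;

namespace {
struct MySolverConfigPass
    : public PassWrapper<MySolverConfigPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MySolverConfigPass)

  MySolverConfigPass() = default;
  MySolverConfigPass(const MySolverConfigPass &other) : PassWrapper(other) {
    // Carry the option value over explicitly when the pass is copied.
    interprocedural = other.interprocedural;
  }

  StringRef getArgument() const override { return "my-solver-config-pass"; }

  Option<bool> interprocedural{
      *this, "interprocedural", llvm::cl::init(true),
      llvm::cl::desc("perform interprocedural analysis")};

  void runOnOperation() override {
    // Thread the flag into the solver configuration; analyses would be loaded
    // and run here as in the test passes above.
    DataFlowSolver solver(
        DataFlowConfig().setInterprocedural(interprocedural));
    (void)solver;
  }
};
} // namespace

Registered as usual, such a pass could then be driven from mlir-opt with --my-solver-config-pass='interprocedural=false'.
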
diff --git a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
index f97a4c8bc5eb3e..e1c60f06a6b5eb 100644
--- a/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
+++ b/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp
@@ -50,7 +50,10 @@ struct WrittenTo : public AbstractSparseLattice {
/// is eventually written to.
class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
public:
- using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
+ WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
+ bool assumeFuncWrites)
+ : SparseBackwardDataFlowAnalysis(solver, symbolTable),
+ assumeFuncWrites(assumeFuncWrites) {}
void visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
ArrayRef<const WrittenTo *> results) override;
@@ -59,7 +62,13 @@ class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
void visitCallOperand(OpOperand &operand) override;
+ void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
+ ArrayRef<const WrittenTo *> results) override;
+
void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
+
+private:
+ bool assumeFuncWrites;
};
void WrittenToAnalysis::visitOperation(Operation *op,
@@ -99,6 +108,26 @@ void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
propagateIfChanged(lattice, lattice->addWrites(newWrites));
}
+void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
+ ArrayRef<WrittenTo *> operands,
+ ArrayRef<const WrittenTo *> results) {
+ if (!assumeFuncWrites) {
+ return SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands,
+ results);
+ }
+
+ for (WrittenTo *lattice : operands) {
+ SetVector<StringAttr> newWrites;
+ StringAttr name = call->getAttrOfType<StringAttr>("tag_name");
+ if (!name) {
+ name = StringAttr::get(call->getContext(),
+ call.getOperation()->getName().getStringRef());
+ }
+ newWrites.insert(name);
+ propagateIfChanged(lattice, lattice->addWrites(newWrites));
+ }
+}
+
} // end anonymous namespace
namespace {
@@ -106,17 +135,31 @@ struct TestWrittenToPass
: public PassWrapper<TestWrittenToPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
+ TestWrittenToPass() = default;
+ TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) {
+ interprocedural = other.interprocedural;
+ assumeFuncWrites = other.assumeFuncWrites;
+ }
+
StringRef getArgument() const override { return "test-written-to"; }
+ Option<bool> interprocedural{
+ *this, "interprocedural", llvm::cl::init(true),
+ llvm::cl::desc("perform interprocedural analysis")};
+ Option<bool> assumeFuncWrites{
+ *this, "assume-func-writes", llvm::cl::init(false),
+ llvm::cl::desc(
+ "assume external functions have write effect on all arguments")};
+
void runOnOperation() override {
Operation *op = getOperation();
SymbolTableCollection symbolTable;
- DataFlowSolver solver;
+ DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
- solver.load<WrittenToAnalysis>(symbolTable);
+ solver.load<WrittenToAnalysis>(symbolTable, assumeFuncWrites);
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
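
Finally, from the perspective of a client of the framework rather than of the test passes, the change boils down to constructing the solver from a DataFlowConfig. A minimal sketch, assuming a client that already has the root operation and wants to run the stock analyses non-interprocedurally (the function name is a placeholder):

// Client-side sketch of the new solver configuration.
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::dataflow;

LogicalResult runIntraproceduralAnalysis(Operation *root) {
  // Keep the solver from descending into callee bodies while reaching the
  // fixpoint; this trades precision for analysis time.
  DataFlowSolver solver(DataFlowConfig().setInterprocedural(false));
  solver.load<DeadCodeAnalysis>();
  solver.load<SparseConstantPropagation>();
  if (failed(solver.initializeAndRun(root)))
    return failure();
  // States can now be queried as usual, e.g. via solver.lookupState<...>().
  return success();
}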