[llvm] 66a0b34 - [Attributor] AAFunctionReachability, Handle CallBase Reachability.
Kuter Dinel via llvm-commits
llvm-commits at lists.llvm.org
Sun Sep 12 15:35:59 PDT 2021
Author: Kuter Dinel
Date: 2021-09-13T01:35:44+03:00
New Revision: 66a0b3464ca8502a6a3b59800f9b31fcd7aa6e97
URL: https://github.com/llvm/llvm-project/commit/66a0b3464ca8502a6a3b59800f9b31fcd7aa6e97
DIFF: https://github.com/llvm/llvm-project/commit/66a0b3464ca8502a6a3b59800f9b31fcd7aa6e97.diff
LOG: [Attributor] AAFunctionReachability, Handle CallBase Reachability.
This patch makes it possible to query call base reachability
(i.e., whether a call base can transitively reach a function Fn).
The patch moves the reachability query handling logic into a member class;
this class will gain more users within the AA once other function
reachability queries are added.
Reviewed By: jdoerfert
Differential Revision: https://reviews.llvm.org/D106402
Added:
Modified:
llvm/include/llvm/Transforms/IPO/Attributor.h
llvm/lib/Transforms/IPO/AttributorAttributes.cpp
llvm/unittests/Transforms/IPO/AttributorTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index 97458a988e590..d0ab2f6e3755c 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -4454,6 +4454,9 @@ struct AAFunctionReachability
/// If the function represented by this possition can reach \p Fn.
virtual bool canReach(Attributor &A, Function *Fn) const = 0;
+ /// Can \p CB (transitively) reach \p Fn?
+ virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
+
/// Create an abstract attribute view for the position \p IRP.
static AAFunctionReachability &createForPosition(const IRPosition &IRP,
Attributor &A);
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 02f1d38f18984..0f73d1925f421 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -9496,115 +9496,180 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
};
struct AAFunctionReachabilityFunction : public AAFunctionReachability {
- AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
- : AAFunctionReachability(IRP, A) {}
+private:
+ struct QuerySet {
+ void markReachable(Function *Fn) {
+ Reachable.insert(Fn);
+ Unreachable.erase(Fn);
+ }
+
+ ChangeStatus update(Attributor &A, const AAFunctionReachability &AA,
+ ArrayRef<const AACallEdges *> AAEdgesList) {
+ ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+ for (auto *AAEdges : AAEdgesList) {
+ if (AAEdges->hasUnknownCallee()) {
+ if (!CanReachUnknownCallee)
+ Change = ChangeStatus::CHANGED;
+ CanReachUnknownCallee = true;
+ return Change;
+ }
+ }
- bool canReach(Attributor &A, Function *Fn) const override {
- // Assume that we can reach any function if we can reach a call with
- // unknown callee.
- if (CanReachUnknownCallee)
- return true;
+ for (Function *Fn : make_early_inc_range(Unreachable)) {
+ if (checkIfReachable(A, AA, AAEdgesList, Fn)) {
+ Change = ChangeStatus::CHANGED;
+ markReachable(Fn);
+ }
+ }
+ return Change;
+ }
- if (ReachableQueries.count(Fn))
- return true;
+ bool isReachable(Attributor &A, const AAFunctionReachability &AA,
+ ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
+ // Assume that we can reach the function.
+ // TODO: Be more specific with the unknown callee.
+ if (CanReachUnknownCallee)
+ return true;
+
+ if (Reachable.count(Fn))
+ return true;
+
+ if (Unreachable.count(Fn))
+ return false;
+
+ // We need to assume that this function can't reach Fn to prevent
+ // an infinite loop if this function is recursive.
+ Unreachable.insert(Fn);
+
+ bool Result = checkIfReachable(A, AA, AAEdgesList, Fn);
+ if (Result)
+ markReachable(Fn);
+ return Result;
+ }
+
+ bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA,
+ ArrayRef<const AACallEdges *> AAEdgesList,
+ Function *Fn) const {
+
+ // Handle the most trivial case first.
+ for (auto *AAEdges : AAEdgesList) {
+ const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();
+
+ if (Edges.count(Fn))
+ return true;
+ }
+
+ SmallVector<const AAFunctionReachability *, 8> Deps;
+ for (auto *AAEdges : AAEdgesList) {
+ const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();
+
+ for (Function *Edge : Edges) {
+ // We don't need a dependency if the result is reachable.
+ const AAFunctionReachability &EdgeReachability =
+ A.getAAFor<AAFunctionReachability>(
+ AA, IRPosition::function(*Edge), DepClassTy::NONE);
+ Deps.push_back(&EdgeReachability);
+
+ if (EdgeReachability.canReach(A, Fn))
+ return true;
+ }
+ }
+
+ // The result is false for now, set dependencies and leave.
+ for (const AAFunctionReachability *Dep : Deps)
+ A.recordDependence(AA, *Dep, DepClassTy::REQUIRED);
- if (UnreachableQueries.count(Fn))
return false;
+ }
+
+ /// Set of functions that we know for sure are reachable.
+ DenseSet<Function *> Reachable;
+
+ /// Set of functions that are unreachable, but might become reachable.
+ DenseSet<Function *> Unreachable;
+ /// If we can reach a function with a call to an unknown function we assume
+ /// that we can reach any function.
+ bool CanReachUnknownCallee = false;
+ };
+
+public:
+ AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
+ : AAFunctionReachability(IRP, A) {}
+
+ bool canReach(Attributor &A, Function *Fn) const override {
const AACallEdges &AAEdges =
A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
- const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges();
- bool Result = checkIfReachable(A, Edges, Fn);
+ // Attributor returns attributes as const, so this function has to be
+ // const for users of this attribute to use it without having to do
+ // a const_cast.
+ // This is a hack for us to be able to cache queries.
+ auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+ bool Result =
+ NonConstThis->WholeFunction.isReachable(A, *this, {&AAEdges}, Fn);
+
+ return Result;
+ }
+
+ /// Can \p CB (transitively) reach \p Fn?
+ bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override {
+ const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
+ *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
// Attributor returns attributes as const, so this function has to be
// const for users of this attribute to use it without having to do
// a const_cast.
// This is a hack for us to be able to cache queries.
auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+ QuerySet &CBQuery = NonConstThis->CBQueries[&CB];
- if (Result)
- NonConstThis->ReachableQueries.insert(Fn);
- else
- NonConstThis->UnreachableQueries.insert(Fn);
+ bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn);
return Result;
}
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
- if (CanReachUnknownCallee)
- return ChangeStatus::UNCHANGED;
-
const AACallEdges &AAEdges =
A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
- const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges();
ChangeStatus Change = ChangeStatus::UNCHANGED;
- if (AAEdges.hasUnknownCallee()) {
- bool OldCanReachUnknown = CanReachUnknownCallee;
- CanReachUnknownCallee = true;
- return OldCanReachUnknown ? ChangeStatus::UNCHANGED
- : ChangeStatus::CHANGED;
- }
+ Change |= WholeFunction.update(A, *this, {&AAEdges});
- // Check if any of the unreachable functions become reachable.
- for (auto Current = UnreachableQueries.begin();
- Current != UnreachableQueries.end();) {
- if (!checkIfReachable(A, Edges, *Current)) {
- Current++;
- continue;
- }
- ReachableQueries.insert(*Current);
- UnreachableQueries.erase(*Current++);
- Change = ChangeStatus::CHANGED;
+ for (auto &CBPair : CBQueries) {
+ const AACallEdges &CBEdges = A.getAAFor<AACallEdges>(
+ *this, IRPosition::callsite_function(*CBPair.first),
+ DepClassTy::REQUIRED);
+ // Update the stored QuerySet in place; a by-value copy would be lost.
+ Change |= CBPair.second.update(A, *this, {&CBEdges});
}
return Change;
}
const std::string getAsStr() const override {
- size_t QueryCount = ReachableQueries.size() + UnreachableQueries.size();
+ size_t QueryCount =
+ WholeFunction.Reachable.size() + WholeFunction.Unreachable.size();
- return "FunctionReachability [" + std::to_string(ReachableQueries.size()) +
- "," + std::to_string(QueryCount) + "]";
+ return "FunctionReachability [" +
+ std::to_string(WholeFunction.Reachable.size()) + "," +
+ std::to_string(QueryCount) + "]";
}
void trackStatistics() const override {}
-
private:
- bool canReachUnknownCallee() const override { return CanReachUnknownCallee; }
-
- bool checkIfReachable(Attributor &A, const SetVector<Function *> &Edges,
- Function *Fn) const {
- if (Edges.count(Fn))
- return true;
-
- for (Function *Edge : Edges) {
- // We don't need a dependency if the result is reachable.
- const AAFunctionReachability &EdgeReachability =
- A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Edge),
- DepClassTy::NONE);
-
- if (EdgeReachability.canReach(A, Fn))
- return true;
- }
- for (Function *Fn : Edges)
- A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Fn),
- DepClassTy::REQUIRED);
-
- return false;
+ bool canReachUnknownCallee() const override {
+ return WholeFunction.CanReachUnknownCallee;
}
- /// Set of functions that we know for sure is reachable.
- SmallPtrSet<Function *, 8> ReachableQueries;
-
- /// Set of functions that are unreachable, but might become reachable.
- SmallPtrSet<Function *, 8> UnreachableQueries;
+ /// Used to answer whether the whole function can reach a specific function.
+ QuerySet WholeFunction;
- /// If we can reach a function with a call to a unknown function we assume
- /// that we can reach any function.
- bool CanReachUnknownCallee = false;
+ /// Used to answer if a call base inside this function can reach a specific
+ /// function.
+ DenseMap<CallBase *, QuerySet> CBQueries;
};
} // namespace
diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
index 51de06c6d5e17..e1d2709e883e1 100644
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -109,6 +109,13 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
call void @func5(void ()* @func3)
ret void
}
+
+ define void @func7() {
+ entry:
+ call void @func2()
+ call void @func4()
+ ret void
+ }
)";
Module &M = parseModule(ModuleString);
@@ -127,6 +134,11 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
Function *F3 = M.getFunction("func3");
Function *F4 = M.getFunction("func4");
Function *F6 = M.getFunction("func6");
+ Function *F7 = M.getFunction("func7");
+
+ // First instruction of @func7 (checked cast): call void @func2()
+ CallBase &F7FirstCB =
+ *cast<CallBase>(F7->getEntryBlock().getFirstNonPHI());
const AAFunctionReachability &F1AA =
A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
@@ -134,15 +146,23 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
const AAFunctionReachability &F6AA =
A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F6));
+ const AAFunctionReachability &F7AA =
+ A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F7));
+
F1AA.canReach(A, F3);
F1AA.canReach(A, F4);
F6AA.canReach(A, F4);
+ F7AA.canReach(A, F7FirstCB, F3);
+ F7AA.canReach(A, F7FirstCB, F4);
A.run();
ASSERT_TRUE(F1AA.canReach(A, F3));
ASSERT_FALSE(F1AA.canReach(A, F4));
+ ASSERT_TRUE(F7AA.canReach(A, F7FirstCB, F3));
+ ASSERT_FALSE(F7AA.canReach(A, F7FirstCB, F4));
+
// Assumed to be reacahable, since F6 can reach a function with
// a unknown callee.
ASSERT_TRUE(F6AA.canReach(A, F4));
More information about the llvm-commits
mailing list