[llvm] b2d1ae0 - [Attributor] AAFunctionReachability, Instruction reachability.
Johannes Doerfert via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 31 23:43:03 PST 2022
Author: Kuter Dinel
Date: 2022-02-01T01:40:44-06:00
New Revision: b2d1ae061153b2364f15f58ee00e46b49bad3544
URL: https://github.com/llvm/llvm-project/commit/b2d1ae061153b2364f15f58ee00e46b49bad3544
DIFF: https://github.com/llvm/llvm-project/commit/b2d1ae061153b2364f15f58ee00e46b49bad3544.diff
LOG: [Attributor] AAFunctionReachability, Instruction reachability.
This patch implements instruction reachability for the AAFunctionReachability
attribute. It is used to tell whether a certain instruction can reach a function
transitively.
NOTE: I created a new commit based on D106720 and set the author back to
Kuter. Other metadata, etc. is wrong. I also addressed the
remaining review comments and fixed the unit test.
Differential Revision: https://reviews.llvm.org/D106720
Added:
Modified:
llvm/include/llvm/Transforms/IPO/Attributor.h
llvm/lib/Transforms/IPO/AttributorAttributes.cpp
llvm/unittests/Transforms/IPO/AttributorTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index b31dceb1b7008..b225a08880d90 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -4616,17 +4616,25 @@ struct AAFunctionReachability
AAFunctionReachability(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
/// If the function represented by this possition can reach \p Fn.
- virtual bool canReach(Attributor &A, Function *Fn) const = 0;
+ virtual bool canReach(Attributor &A, const Function &Fn) const = 0;
/// Can \p CB reach \p Fn
- virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
+ virtual bool canReach(Attributor &A, CallBase &CB,
+ const Function &Fn) const = 0;
+
+ /// Can \p Inst reach \p Fn
+ virtual bool instructionCanReach(Attributor &A, const Instruction &Inst,
+ const Function &Fn,
+ bool UseBackwards = true) const = 0;
/// Create an abstract attribute view for the position \p IRP.
static AAFunctionReachability &createForPosition(const IRPosition &IRP,
Attributor &A);
/// See AbstractAttribute::getName()
- const std::string getName() const override { return "AAFunctionReachability"; }
+ const std::string getName() const override {
+ return "AAFunctionReachability";
+ }
/// See AbstractAttribute::getIdAddr()
const char *getIdAddr() const override { return &ID; }
diff --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 5aad4ae4742c4..92ce8c141ca6b 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -656,7 +656,7 @@ struct AACallSiteReturnedFromReturned : public BaseType {
if (!AssociatedFunction)
return S.indicatePessimisticFixpoint();
- CallBase &CBContext = static_cast<CallBase &>(this->getAnchorValue());
+ CallBase &CBContext = cast<CallBase>(this->getAnchorValue());
if (IntroduceCallBaseContext)
LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:"
<< CBContext << "\n");
@@ -2468,7 +2468,7 @@ struct AANoRecurseFunction final : AANoRecurseImpl {
const AAFunctionReachability &EdgeReachability =
A.getAAFor<AAFunctionReachability>(*this, getIRPosition(),
DepClassTy::REQUIRED);
- if (EdgeReachability.canReach(A, getAnchorScope()))
+ if (EdgeReachability.canReach(A, *getAnchorScope()))
return indicatePessimisticFixpoint();
return ChangeStatus::UNCHANGED;
}
@@ -9482,7 +9482,7 @@ struct AACallEdgesCallSite : public AACallEdgesImpl {
}
};
- CallBase *CB = static_cast<CallBase *>(getCtxI());
+ CallBase *CB = cast<CallBase>(getCtxI());
if (CB->isInlineAsm()) {
setHasUnknownCallee(false, Change);
@@ -9521,7 +9521,7 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
ChangeStatus Change = ChangeStatus::UNCHANGED;
auto ProcessCallInst = [&](Instruction &Inst) {
- CallBase &CB = static_cast<CallBase &>(Inst);
+ CallBase &CB = cast<CallBase>(Inst);
auto &CBEdges = A.getAAFor<AACallEdges>(
*this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
@@ -9552,11 +9552,39 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
struct AAFunctionReachabilityFunction : public AAFunctionReachability {
private:
struct QuerySet {
- void markReachable(Function *Fn) {
- Reachable.insert(Fn);
- Unreachable.erase(Fn);
+ void markReachable(const Function &Fn) {
+ Reachable.insert(&Fn);
+ Unreachable.erase(&Fn);
}
+ /// If there is no information about the function None is returned.
+ Optional<bool> isCachedReachable(const Function &Fn) {
+ // Assume that we can reach the function.
+ // TODO: Be more specific with the unknown callee.
+ if (CanReachUnknownCallee)
+ return true;
+
+ if (Reachable.count(&Fn))
+ return true;
+
+ if (Unreachable.count(&Fn))
+ return false;
+
+ return llvm::None;
+ }
+
+ /// Set of functions that we know for sure is reachable.
+ DenseSet<const Function *> Reachable;
+
+ /// Set of functions that are unreachable, but might become reachable.
+ DenseSet<const Function *> Unreachable;
+
+ /// If we can reach a function with a call to a unknown function we assume
+ /// that we can reach any function.
+ bool CanReachUnknownCallee = false;
+ };
+
+ struct QueryResolver : public QuerySet {
ChangeStatus update(Attributor &A, const AAFunctionReachability &AA,
ArrayRef<const AACallEdges *> AAEdgesList) {
ChangeStatus Change = ChangeStatus::UNCHANGED;
@@ -9570,31 +9598,25 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
}
}
- for (Function *Fn : make_early_inc_range(Unreachable)) {
- if (checkIfReachable(A, AA, AAEdgesList, Fn)) {
+ for (const Function *Fn : make_early_inc_range(Unreachable)) {
+ if (checkIfReachable(A, AA, AAEdgesList, *Fn)) {
Change = ChangeStatus::CHANGED;
- markReachable(Fn);
+ markReachable(*Fn);
}
}
return Change;
}
bool isReachable(Attributor &A, const AAFunctionReachability &AA,
- ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
- // Assume that we can reach the function.
- // TODO: Be more specific with the unknown callee.
- if (CanReachUnknownCallee)
- return true;
-
- if (Reachable.count(Fn))
- return true;
-
- if (Unreachable.count(Fn))
- return false;
+ ArrayRef<const AACallEdges *> AAEdgesList,
+ const Function &Fn) {
+ Optional<bool> Cached = isCachedReachable(Fn);
+ if (Cached.hasValue())
+ return Cached.getValue();
// We need to assume that this function can't reach Fn to prevent
// an infinite loop if this function is recursive.
- Unreachable.insert(Fn);
+ Unreachable.insert(&Fn);
bool Result = checkIfReachable(A, AA, AAEdgesList, Fn);
if (Result)
@@ -9604,13 +9626,13 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA,
ArrayRef<const AACallEdges *> AAEdgesList,
- Function *Fn) const {
+ const Function &Fn) const {
// Handle the most trivial case first.
for (auto *AAEdges : AAEdgesList) {
const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();
- if (Edges.count(Fn))
+ if (Edges.count(const_cast<Function *>(&Fn)))
return true;
}
@@ -9631,28 +9653,80 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
}
// The result is false for now, set dependencies and leave.
- for (auto Dep : Deps)
- A.recordDependence(AA, *Dep, DepClassTy::REQUIRED);
+ for (auto *Dep : Deps)
+ A.recordDependence(*Dep, AA, DepClassTy::REQUIRED);
return false;
}
+ };
- /// Set of functions that we know for sure is reachable.
- DenseSet<Function *> Reachable;
+ /// Get call edges that can be reached by this instruction.
+ bool getReachableCallEdges(Attributor &A, const AAReachability &Reachability,
+ const Instruction &Inst,
+ SmallVector<const AACallEdges *> &Result) const {
+ // Determine call like instructions that we can reach from the inst.
+ auto CheckCallBase = [&](Instruction &CBInst) {
+ if (!Reachability.isAssumedReachable(A, Inst, CBInst))
+ return true;
- /// Set of functions that are unreachable, but might become reachable.
- DenseSet<Function *> Unreachable;
+ const auto &CB = cast<CallBase>(CBInst);
+ const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
+ *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
- /// If we can reach a function with a call to a unknown function we assume
- /// that we can reach any function.
- bool CanReachUnknownCallee = false;
- };
+ Result.push_back(&AAEdges);
+ return true;
+ };
+
+ bool UsedAssumedInformation = false;
+ return A.checkForAllCallLikeInstructions(CheckCallBase, *this,
+ UsedAssumedInformation);
+ }
+
+ ChangeStatus checkReachableBackwards(Attributor &A, QuerySet &Set) {
+ ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+ // For all remaining instruction queries, check
+ // callers. A call inside that function might satisfy the query.
+ auto CheckCallSite = [&](AbstractCallSite CallSite) {
+ CallBase *CB = CallSite.getInstruction();
+ if (!CB)
+ return false;
+
+ if (isa<InvokeInst>(CB))
+ return false;
+
+ Instruction *Inst = CB->getNextNonDebugInstruction();
+ const AAFunctionReachability &AA = A.getAAFor<AAFunctionReachability>(
+ *this, IRPosition::function(*Inst->getFunction()),
+ DepClassTy::REQUIRED);
+ for (const Function *Fn : make_early_inc_range(Set.Unreachable)) {
+ if (AA.instructionCanReach(A, *Inst, *Fn, /* UseBackwards */ false)) {
+ Set.markReachable(*Fn);
+ Change = ChangeStatus::CHANGED;
+ }
+ }
+ return true;
+ };
+
+ bool NoUnknownCall = true;
+ if (A.checkForAllCallSites(CheckCallSite, *this, true, NoUnknownCall))
+ return Change;
+
+ // If we don't know all callsites we have to assume that we can reach fn.
+ for (auto &QSet : InstQueriesBackwards) {
+ if (!QSet.second.CanReachUnknownCallee)
+ Change = ChangeStatus::CHANGED;
+ QSet.second.CanReachUnknownCallee = true;
+ }
+
+ return Change;
+ }
public:
AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
: AAFunctionReachability(IRP, A) {}
- bool canReach(Attributor &A, Function *Fn) const override {
+ bool canReach(Attributor &A, const Function &Fn) const override {
const AACallEdges &AAEdges =
A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
@@ -9668,7 +9742,8 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
}
/// Can \p CB reach \p Fn
- bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override {
+ bool canReach(Attributor &A, CallBase &CB,
+ const Function &Fn) const override {
const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
*this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
@@ -9677,13 +9752,52 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
// a const_cast.
// This is a hack for us to be able to cache queries.
auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
- QuerySet &CBQuery = NonConstThis->CBQueries[&CB];
+ QueryResolver &CBQuery = NonConstThis->CBQueries[&CB];
bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn);
return Result;
}
+ bool instructionCanReach(Attributor &A, const Instruction &Inst,
+ const Function &Fn,
+ bool UseBackwards) const override {
+ const auto &Reachability = &A.getAAFor<AAReachability>(
+ *this, IRPosition::function(*getAssociatedFunction()),
+ DepClassTy::REQUIRED);
+
+ SmallVector<const AACallEdges *> CallEdges;
+ bool AllKnown = getReachableCallEdges(A, *Reachability, Inst, CallEdges);
+ // Attributor returns attributes as const, so this function has to be
+ // const for users of this attribute to use it without having to do
+ // a const_cast.
+ // This is a hack for us to be able to cache queries.
+ auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+ QueryResolver &InstQSet = NonConstThis->InstQueries[&Inst];
+ if (!AllKnown)
+ InstQSet.CanReachUnknownCallee = true;
+
+ bool ForwardsResult = InstQSet.isReachable(A, *this, CallEdges, Fn);
+ if (ForwardsResult)
+ return true;
+ // We are done.
+ if (!UseBackwards)
+ return false;
+
+ QuerySet &InstBackwardsQSet = NonConstThis->InstQueriesBackwards[&Inst];
+
+ Optional<bool> BackwardsCached = InstBackwardsQSet.isCachedReachable(Fn);
+ if (BackwardsCached.hasValue())
+ return BackwardsCached.getValue();
+
+ // Assume unreachable, to prevent problems.
+ InstBackwardsQSet.Unreachable.insert(&Fn);
+
+ // Check backwards reachability.
+ NonConstThis->checkReachableBackwards(A, InstBackwardsQSet);
+ return InstBackwardsQSet.isCachedReachable(Fn).getValue();
+ }
+
/// See AbstractAttribute::updateImpl(...).
ChangeStatus updateImpl(Attributor &A) override {
const AACallEdges &AAEdges =
@@ -9692,7 +9806,7 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
Change |= WholeFunction.update(A, *this, {&AAEdges});
- for (auto CBPair : CBQueries) {
+ for (auto &CBPair : CBQueries) {
const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
*this, IRPosition::callsite_function(*CBPair.first),
DepClassTy::REQUIRED);
@@ -9700,6 +9814,29 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
Change |= CBPair.second.update(A, *this, {&AAEdges});
}
+ // Update the Instruction queries.
+ const AAReachability *Reachability;
+ if (!InstQueries.empty()) {
+ Reachability = &A.getAAFor<AAReachability>(
+ *this, IRPosition::function(*getAssociatedFunction()),
+ DepClassTy::REQUIRED);
+ }
+
+ // Check for local callbases first.
+ for (auto &InstPair : InstQueries) {
+ SmallVector<const AACallEdges *> CallEdges;
+ bool AllKnown =
+ getReachableCallEdges(A, *Reachability, *InstPair.first, CallEdges);
+ // Update will return change if we this effects any queries.
+ if (!AllKnown)
+ InstPair.second.CanReachUnknownCallee = true;
+ Change |= InstPair.second.update(A, *this, CallEdges);
+ }
+
+ // Update backwards queries.
+ for (auto &QueryPair : InstQueriesBackwards)
+ Change |= checkReachableBackwards(A, QueryPair.second);
+
return Change;
}
@@ -9720,11 +9857,17 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
}
/// Used to answer if a the whole function can reacha a specific function.
- QuerySet WholeFunction;
+ QueryResolver WholeFunction;
/// Used to answer if a call base inside this function can reach a specific
/// function.
- DenseMap<CallBase *, QuerySet> CBQueries;
+ DenseMap<const CallBase *, QueryResolver> CBQueries;
+
+ /// This is for instruction queries than scan "forward".
+ DenseMap<const Instruction *, QueryResolver> InstQueries;
+
+ /// This is for instruction queries than scan "backward".
+ DenseMap<const Instruction *, QuerySet> InstQueriesBackwards;
};
/// ---------------------- Assumption Propagation ------------------------------
diff --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
index e1d2709e883e1..ea4f1ba70f4c6 100644
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -75,47 +75,69 @@ TEST_F(AttributorTestBase, TestCast) {
TEST_F(AttributorTestBase, AAReachabilityTest) {
const char *ModuleString = R"(
- @x = global i32 0
- define void @func4() {
+ @x = external global i32
+ define internal void @func4() {
store i32 0, i32* @x
ret void
}
- define void @func3() {
+ define internal void @func3() {
store i32 0, i32* @x
ret void
}
- define void @func2() {
+ define internal void @func8() {
+ store i32 0, i32* @x
+ ret void
+ }
+
+ define internal void @func2() {
entry:
call void @func3()
ret void
}
- define void @func1() {
+ define internal void @func1() {
entry:
call void @func2()
ret void
}
- define void @func5(void ()* %unknown) {
+ declare void @unknown()
+ define internal void @func5(void ()* %ptr) {
entry:
- call void %unknown()
+ call void %ptr()
+ call void @unknown()
ret void
}
- define void @func6() {
+ define internal void @func6() {
entry:
call void @func5(void ()* @func3)
ret void
}
- define void @func7() {
+ define internal void @func7() {
+ entry:
+ call void @func2()
+ call void @func4()
+ ret void
+ }
+
+ define internal void @func9() {
entry:
call void @func2()
+ call void @func8()
+ ret void
+ }
+
+ define void @func10() {
+ entry:
+ call void @func9()
call void @func4()
ret void
}
+
)";
Module &M = parseModule(ModuleString);
@@ -128,32 +150,43 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
CallGraphUpdater CGUpdater;
BumpPtrAllocator Allocator;
InformationCache InfoCache(M, AG, Allocator, nullptr);
- Attributor A(Functions, InfoCache, CGUpdater);
+ Attributor A(Functions, InfoCache, CGUpdater, /* Allowed */ nullptr,
+ /*DeleteFns*/ false);
- Function *F1 = M.getFunction("func1");
- Function *F3 = M.getFunction("func3");
- Function *F4 = M.getFunction("func4");
- Function *F6 = M.getFunction("func6");
- Function *F7 = M.getFunction("func7");
+ Function &F1 = *M.getFunction("func1");
+ Function &F3 = *M.getFunction("func3");
+ Function &F4 = *M.getFunction("func4");
+ Function &F6 = *M.getFunction("func6");
+ Function &F7 = *M.getFunction("func7");
+ Function &F9 = *M.getFunction("func9");
// call void @func2()
- CallBase &F7FirstCB =
- *static_cast<CallBase *>(F7->getEntryBlock().getFirstNonPHI());
+ CallBase &F7FirstCB = static_cast<CallBase &>(*F7.getEntryBlock().begin());
+ // call void @func2()
+ Instruction &F9FirstInst = *F9.getEntryBlock().begin();
+ // call void @func8
+ Instruction &F9SecondInst = *++(F9.getEntryBlock().begin());
const AAFunctionReachability &F1AA =
- A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
+ A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F1));
const AAFunctionReachability &F6AA =
- A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F6));
+ A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F6));
const AAFunctionReachability &F7AA =
- A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F7));
+ A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F7));
+
+ const AAFunctionReachability &F9AA =
+ A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F9));
F1AA.canReach(A, F3);
F1AA.canReach(A, F4);
F6AA.canReach(A, F4);
F7AA.canReach(A, F7FirstCB, F3);
F7AA.canReach(A, F7FirstCB, F4);
+ F9AA.instructionCanReach(A, F9FirstInst, F3);
+ F9AA.instructionCanReach(A, F9SecondInst, F3, false);
+ F9AA.instructionCanReach(A, F9FirstInst, F4);
A.run();
@@ -166,6 +199,15 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
// Assumed to be reacahable, since F6 can reach a function with
// a unknown callee.
ASSERT_TRUE(F6AA.canReach(A, F4));
+
+ // The second instruction of F9 can't reach the first call.
+ ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, false));
+ ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, true));
+
+ // The first instruction of F9 can reach the first call.
+ ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F3));
+ // Because func10 calls the func4 after the call to func9 it is reachable.
+ ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F4));
}
} // namespace llvm
More information about the llvm-commits
mailing list