[llvm] b7d870e - [AssumptionCache] Avoid dangling llvm.assume calls in the cache
Johannes Doerfert via llvm-commits
llvm-commits at lists.llvm.org
Sat Feb 6 10:18:45 PST 2021
Author: Johannes Doerfert
Date: 2021-02-06T12:18:39-06:00
New Revision: b7d870eae7fdadcf10d0f177faa7409c2e37d776
URL: https://github.com/llvm/llvm-project/commit/b7d870eae7fdadcf10d0f177faa7409c2e37d776
DIFF: https://github.com/llvm/llvm-project/commit/b7d870eae7fdadcf10d0f177faa7409c2e37d776.diff
LOG: [AssumptionCache] Avoid dangling llvm.assume calls in the cache
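PR49043 exposed a problem with llvm.assume calls that are replaced via
RAUW (replaceAllUsesWith). While D96106 would fix it for GVNSink, the
problem appears to be more general. To avoid future problems, this patch
moves away from the vector-of-weak-references model used in the
assumption cache. Instead, we track the llvm.assume calls with a
callback handle that removes itself from the cache if the call is
deleted. Because the handles are now kept in a DenseSet, assumptions are
no longer iterated in the order they were found, so the
AssumptionCache/basic.ll test now uses CHECK-DAG instead of CHECK-NEXT.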
Fixes PR49043.
Reviewed By: nikic
Differential Revision: https://reviews.llvm.org/D96168
Added:
Modified:
llvm/include/llvm/Analysis/AssumptionCache.h
llvm/lib/Analysis/AssumptionCache.cpp
llvm/lib/Analysis/CodeMetrics.cpp
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
llvm/lib/Transforms/Utils/CodeExtractor.cpp
llvm/lib/Transforms/Utils/PredicateInfo.cpp
llvm/test/Analysis/AssumptionCache/basic.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/AssumptionCache.h b/llvm/include/llvm/Analysis/AssumptionCache.h
index c4602d3449c0..b9ffd9a6c535 100644
--- a/llvm/include/llvm/Analysis/AssumptionCache.h
+++ b/llvm/include/llvm/Analysis/AssumptionCache.h
@@ -18,7 +18,9 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Pass.h"
@@ -44,6 +46,22 @@ class AssumptionCache {
/// llvm.assume.
enum : unsigned { ExprResultIdx = std::numeric_limits<unsigned>::max() };
+ /// Callback handle to ensure we do not have dangling pointers to llvm.assume
+ /// calls in our cache.
+ class AssumeHandle final : public CallbackVH {
+ AssumptionCache *AC;
+
+ /// Make sure llvm.assume calls that are deleted are removed from the cache.
+ void deleted() override;
+
+ public:
+ AssumeHandle(Value *V, AssumptionCache *AC = nullptr)
+ : CallbackVH(V), AC(AC) {}
+
+ operator Value *() const { return getValPtr(); }
+ CallInst *getAssumeCI() const { return cast<CallInst>(getValPtr()); }
+ };
+
struct ResultElem {
WeakVH Assume;
@@ -59,9 +77,9 @@ class AssumptionCache {
/// We track this to lazily populate our assumptions.
Function &F;
- /// Vector of weak value handles to calls of the \@llvm.assume
- /// intrinsic.
- SmallVector<ResultElem, 4> AssumeHandles;
+ /// Set of value handles for calls of the \@llvm.assume intrinsic.
+ using AssumeHandleSet = DenseSet<AssumeHandle, DenseMapInfo<Value *>>;
+ AssumeHandleSet AssumeHandles;
class AffectedValueCallbackVH final : public CallbackVH {
AssumptionCache *AC;
@@ -137,13 +155,7 @@ class AssumptionCache {
/// Access the list of assumption handles currently tracked for this
/// function.
- ///
- /// Note that these produce weak handles that may be null. The caller must
- /// handle that case.
- /// FIXME: We should replace this with pointee_iterator<filter_iterator<...>>
- /// when we can write that to filter out the null values. Then caller code
- /// will become simpler.
- MutableArrayRef<ResultElem> assumptions() {
+ AssumeHandleSet &assumptions() {
if (!Scanned)
scanFunction();
return AssumeHandles;
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index 70053fdf8d30..e2a31d6618c4 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -163,7 +163,12 @@ void AssumptionCache::unregisterAssumption(CallInst *CI) {
AffectedValues.erase(AVI);
}
- erase_value(AssumeHandles, CI);
+ AssumeHandles.erase({CI, this});
+}
+
+void AssumptionCache::AssumeHandle::deleted() {
+ AC->AssumeHandles.erase(*this);
+ // 'this' now dangles!
}
void AssumptionCache::AffectedValueCallbackVH::deleted() {
@@ -204,14 +209,14 @@ void AssumptionCache::scanFunction() {
for (BasicBlock &B : F)
for (Instruction &II : B)
if (match(&II, m_Intrinsic<Intrinsic::assume>()))
- AssumeHandles.push_back({&II, ExprResultIdx});
+ AssumeHandles.insert({&II, this});
// Mark the scan as complete.
Scanned = true;
// Update affected values.
- for (auto &A : AssumeHandles)
- updateAffectedValues(cast<CallInst>(A));
+ for (auto &AssumeVH : AssumeHandles)
+ updateAffectedValues(AssumeVH.getAssumeCI());
}
void AssumptionCache::registerAssumption(CallInst *CI) {
@@ -223,7 +228,7 @@ void AssumptionCache::registerAssumption(CallInst *CI) {
if (!Scanned)
return;
- AssumeHandles.push_back({CI, ExprResultIdx});
+ AssumeHandles.insert({CI, this});
#ifndef NDEBUG
assert(CI->getParent() &&
@@ -231,20 +236,11 @@ void AssumptionCache::registerAssumption(CallInst *CI) {
assert(&F == CI->getParent()->getParent() &&
"Cannot register @llvm.assume call not in this function");
- // We expect the number of assumptions to be small, so in an asserts build
- // check that we don't accumulate duplicates and that all assumptions point
- // to the same function.
- SmallPtrSet<Value *, 16> AssumptionSet;
- for (auto &VH : AssumeHandles) {
- if (!VH)
- continue;
-
- assert(&F == cast<Instruction>(VH)->getParent()->getParent() &&
+ for (auto &AssumeVH : AssumeHandles) {
+ assert(&F == AssumeVH.getAssumeCI()->getCaller() &&
"Cached assumption not inside this function!");
- assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) &&
+ assert(match(AssumeVH.getAssumeCI(), m_Intrinsic<Intrinsic::assume>()) &&
"Cached something other than a call to @llvm.assume!");
- assert(AssumptionSet.insert(VH).second &&
- "Cache contains multiple copies of a call!");
}
#endif
@@ -258,9 +254,8 @@ PreservedAnalyses AssumptionPrinterPass::run(Function &F,
AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
OS << "Cached assumptions for function: " << F.getName() << "\n";
- for (auto &VH : AC.assumptions())
- if (VH)
- OS << " " << *cast<CallInst>(VH)->getArgOperand(0) << "\n";
+ for (auto &AssumeVH : AC.assumptions())
+ OS << " " << *AssumeVH.getAssumeCI()->getArgOperand(0) << "\n";
return PreservedAnalyses::all();
}
@@ -306,9 +301,8 @@ void AssumptionCacheTracker::verifyAnalysis() const {
SmallPtrSet<const CallInst *, 4> AssumptionSet;
for (const auto &I : AssumptionCaches) {
- for (auto &VH : I.second->assumptions())
- if (VH)
- AssumptionSet.insert(cast<CallInst>(VH));
+ for (auto &AssumeVH : I.second->assumptions())
+ AssumptionSet.insert(AssumeVH.getAssumeCI());
for (const BasicBlock &B : cast<Function>(*I.first))
for (const Instruction &II : B)
diff --git a/llvm/lib/Analysis/CodeMetrics.cpp b/llvm/lib/Analysis/CodeMetrics.cpp
index 157811c04eb5..b0b46cbdbdcc 100644
--- a/llvm/lib/Analysis/CodeMetrics.cpp
+++ b/llvm/lib/Analysis/CodeMetrics.cpp
@@ -73,9 +73,7 @@ void CodeMetrics::collectEphemeralValues(
SmallVector<const Value *, 16> Worklist;
for (auto &AssumeVH : AC->assumptions()) {
- if (!AssumeVH)
- continue;
- Instruction *I = cast<Instruction>(AssumeVH);
+ Instruction *I = AssumeVH.getAssumeCI();
// Filter out call sites outside of the loop so we don't do a function's
// worth of work for each of its loops (and, in the common case, ephemeral
@@ -97,9 +95,7 @@ void CodeMetrics::collectEphemeralValues(
SmallVector<const Value *, 16> Worklist;
for (auto &AssumeVH : AC->assumptions()) {
- if (!AssumeVH)
- continue;
- Instruction *I = cast<Instruction>(AssumeVH);
+ Instruction *I = AssumeVH.getAssumeCI();
assert(I->getParent()->getParent() == F &&
"Found assumption for the wrong function!");
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 484d3387acba..b207fc89f89b 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1704,9 +1704,9 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
getZeroExtendExpr(Step, Ty, Depth + 1), L,
AR->getNoWrapFlags());
}
-
+
// For a negative step, we can extend the operands iff doing so only
- // traverses values in the range zext([0,UINT_MAX]).
+ // traverses values in the range zext([0,UINT_MAX]).
if (isKnownNegative(Step)) {
const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
getSignedRangeMin(Step));
@@ -9927,9 +9927,7 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
// Check conditions due to any @llvm.assume intrinsics.
for (auto &AssumeVH : AC.assumptions()) {
- if (!AssumeVH)
- continue;
- auto *CI = cast<CallInst>(AssumeVH);
+ auto *CI = AssumeVH.getAssumeCI();
if (!DT.dominates(CI, Latch->getTerminator()))
continue;
@@ -10076,9 +10074,7 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
// Check conditions due to any @llvm.assume intrinsics.
for (auto &AssumeVH : AC.assumptions()) {
- if (!AssumeVH)
- continue;
- auto *CI = cast<CallInst>(AssumeVH);
+ auto *CI = AssumeVH.getAssumeCI();
if (!DT.dominates(CI, BB))
continue;
@@ -13358,9 +13354,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
// Also collect information from assumptions dominating the loop.
for (auto &AssumeVH : AC.assumptions()) {
- if (!AssumeVH)
- continue;
- auto *AssumeI = cast<CallInst>(AssumeVH);
+ auto *AssumeI = AssumeVH.getAssumeCI();
auto *Cmp = dyn_cast<ICmpInst>(AssumeI->getOperand(0));
if (!Cmp || !DT.dominates(AssumeI, L->getHeader()))
continue;
diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
index bccf94fc217f..469060a93146 100644
--- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
+++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -331,12 +331,11 @@ bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
DT = DT_;
bool Changed = false;
- for (auto &AssumeVH : AC.assumptions())
- if (AssumeVH) {
- CallInst *Call = cast<CallInst>(AssumeVH);
- for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
- Changed |= processAssumption(Call, Idx);
- }
+ for (auto &AssumeVH : AC.assumptions()) {
+ CallInst *Call = AssumeVH.getAssumeCI();
+ for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
+ Changed |= processAssumption(Call, Idx);
+ }
return Changed;
}
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 70a1257e5901..461be91f2623 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -1781,10 +1781,8 @@ CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
const Function &NewFunc,
AssumptionCache *AC) {
- for (auto AssumeVH : AC->assumptions()) {
- auto *I = dyn_cast_or_null<CallInst>(AssumeVH);
- if (!I)
- continue;
+ for (auto &AssumeVH : AC->assumptions()) {
+ auto *I = AssumeVH.getAssumeCI();
// There shouldn't be any llvm.assume intrinsics in the new function.
if (I->getFunction() != &OldFunc)
diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
index 3312a6f9459b..af5a72aa6ad9 100644
--- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp
+++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp
@@ -532,10 +532,11 @@ void PredicateInfoBuilder::buildPredicateInfo() {
processSwitch(SI, BranchBB, OpsToRename);
}
}
- for (auto &Assume : AC.assumptions()) {
- if (auto *II = dyn_cast_or_null<IntrinsicInst>(Assume))
- if (DT.isReachableFromEntry(II->getParent()))
- processAssume(II, II->getParent(), OpsToRename);
+ for (auto &AssumeVH : AC.assumptions()) {
+ CallInst *AssumeCI = AssumeVH.getAssumeCI();
+ if (DT.isReachableFromEntry(AssumeCI->getParent()))
+ processAssume(cast<IntrinsicInst>(AssumeCI), AssumeCI->getParent(),
+ OpsToRename);
}
// Now rename all our operations.
renameUses(OpsToRename);
diff --git a/llvm/test/Analysis/AssumptionCache/basic.ll b/llvm/test/Analysis/AssumptionCache/basic.ll
index bd4e7b6449fb..161fe10ed04b 100644
--- a/llvm/test/Analysis/AssumptionCache/basic.ll
+++ b/llvm/test/Analysis/AssumptionCache/basic.ll
@@ -6,9 +6,9 @@ declare void @llvm.assume(i1)
define void @test1(i32 %a) {
; CHECK-LABEL: Cached assumptions for function: test1
-; CHECK-NEXT: icmp ne i32 %{{.*}}, 0
-; CHECK-NEXT: icmp slt i32 %{{.*}}, 0
-; CHECK-NEXT: icmp sgt i32 %{{.*}}, 0
+; CHECK-DAG: icmp ne i32 %{{.*}}, 0
+; CHECK-DAG: icmp slt i32 %{{.*}}, 0
+; CHECK-DAG: icmp sgt i32 %{{.*}}, 0
entry:
%cond1 = icmp ne i32 %a, 0
More information about the llvm-commits
mailing list