[llvm] 813f438 - [AssumeBundles] adapt Assumption cache to assume bundles
via llvm-commits
llvm-commits at lists.llvm.org
Mon Apr 13 03:12:30 PDT 2020
Author: Tyker
Date: 2020-04-13T12:04:51+02:00
New Revision: 813f438baaa9638529023b2218875e01ea037735
URL: https://github.com/llvm/llvm-project/commit/813f438baaa9638529023b2218875e01ea037735
DIFF: https://github.com/llvm/llvm-project/commit/813f438baaa9638529023b2218875e01ea037735.diff
LOG: [AssumeBundles] adapt Assumption cache to assume bundles
Summary: change assumption cache to store an assume along with an index to the operand bundle containing the knowledge.
Reviewers: jdoerfert, hfinkel
Reviewed By: jdoerfert
Subscribers: hiraditya, mgrang, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D77402
Added:
Modified:
llvm/include/llvm/Analysis/AssumptionCache.h
llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
llvm/lib/Analysis/AssumptionCache.cpp
llvm/lib/Transforms/Scalar/EarlyCSE.cpp
llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
llvm/lib/Transforms/Utils/InlineFunction.cpp
llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/AssumptionCache.h b/llvm/include/llvm/Analysis/AssumptionCache.h
index 0efbd59023d6..0ef63dc68e1c 100644
--- a/llvm/include/llvm/Analysis/AssumptionCache.h
+++ b/llvm/include/llvm/Analysis/AssumptionCache.h
@@ -39,6 +39,21 @@ class Value;
/// register any new \@llvm.assume calls that they create. Deletions of
/// \@llvm.assume calls do not require special handling.
class AssumptionCache {
+public:
+ /// Value of ResultElem::Index indicating that the knowledge comes from the
+ /// condition argument (operand 0) of the llvm.assume call rather than from
+ /// one of its operand bundles.
+ enum : unsigned { ExprResultIdx = std::numeric_limits<unsigned>::max() };
+
+ /// A cached assumption: the llvm.assume call together with which part of it
+ /// (its condition or one of its operand bundles) carries the knowledge.
+ struct ResultElem {
+ /// Weak handle to the llvm.assume call. It can be null, e.g. after
+ /// unregisterAssumption() has cleared this entry, so check before use.
+ WeakTrackingVH Assume;
+
+ /// Either ExprResultIdx or the index of the operand bundle containing the
+ /// knowledge.
+ unsigned Index;
+ /// Implicit conversion to the (possibly null) llvm.assume call.
+ operator Value *() const { return Assume; }
+ };
+
+private:
/// The function for which this cache is handling assumptions.
///
/// We track this to lazily populate our assumptions.
@@ -46,7 +61,7 @@ class AssumptionCache {
/// Vector of weak value handles to calls of the \@llvm.assume
/// intrinsic.
- SmallVector<WeakTrackingVH, 4> AssumeHandles;
+ SmallVector<ResultElem, 4> AssumeHandles;
class AffectedValueCallbackVH final : public CallbackVH {
AssumptionCache *AC;
@@ -66,12 +81,12 @@ class AssumptionCache {
/// A map of values about which an assumption might be providing
/// information to the relevant set of assumptions.
using AffectedValuesMap =
- DenseMap<AffectedValueCallbackVH, SmallVector<WeakTrackingVH, 1>,
+ DenseMap<AffectedValueCallbackVH, SmallVector<ResultElem, 1>,
AffectedValueCallbackVH::DMI>;
AffectedValuesMap AffectedValues;
/// Get the vector of assumptions which affect a value from the cache.
- SmallVector<WeakTrackingVH, 1> &getOrInsertAffectedValues(Value *V);
+ SmallVector<ResultElem, 1> &getOrInsertAffectedValues(Value *V);
/// Move affected values in the cache for OV to be affected values for NV.
void transferAffectedValuesInCache(Value *OV, Value *NV);
@@ -128,20 +143,20 @@ class AssumptionCache {
/// FIXME: We should replace this with pointee_iterator<filter_iterator<...>>
/// when we can write that to filter out the null values. Then caller code
/// will become simpler.
- MutableArrayRef<WeakTrackingVH> assumptions() {
+ MutableArrayRef<ResultElem> assumptions() {
if (!Scanned)
scanFunction();
return AssumeHandles;
}
/// Access the list of assumptions which affect this value.
- MutableArrayRef<WeakTrackingVH> assumptionsFor(const Value *V) {
+ MutableArrayRef<ResultElem> assumptionsFor(const Value *V) {
if (!Scanned)
scanFunction();
auto AVI = AffectedValues.find_as(const_cast<Value *>(V));
if (AVI == AffectedValues.end())
- return MutableArrayRef<WeakTrackingVH>();
+ return MutableArrayRef<ResultElem>();
return AVI->second;
}
@@ -234,6 +249,21 @@ class AssumptionCacheTracker : public ImmutablePass {
static char ID; // Pass identification, replacement for typeid
};
+/// Lets a ResultElem be simplified to the llvm.assume call it refers to, via
+/// ResultElem::operator Value *(). Presumably this is what allows LLVM's
+/// casting utilities (isa<>, cast<>, dyn_cast<>) to be applied directly to
+/// cache entries -- TODO confirm against llvm/Support/Casting.h.
+template<> struct simplify_type<AssumptionCache::ResultElem> {
+ using SimpleType = Value *;
+
+ static SimpleType getSimplifiedValue(AssumptionCache::ResultElem &Val) {
+ return Val;
+ }
+};
+/// Const-qualified counterpart of the specialization above.
+template<> struct simplify_type<const AssumptionCache::ResultElem> {
+ /// Stays a non-const Value *, matching what ResultElem::operator Value *()
+ /// const returns; constness of the entry is not propagated to the pointee.
+ using SimpleType = /*const*/ Value *;
+
+ static SimpleType getSimplifiedValue(const AssumptionCache::ResultElem &Val) {
+ return Val;
+ }
+};
+
} // end namespace llvm
#endif // LLVM_ANALYSIS_ASSUMPTIONCACHE_H
diff --git a/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h b/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
index 50f89194411a..f1cd0239b008 100644
--- a/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
+++ b/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
@@ -22,6 +22,7 @@
namespace llvm {
class IntrinsicInst;
+class AssumptionCache;
/// Build a call to llvm.assume to preserve informations that can be derived
/// from the given instruction.
@@ -32,7 +33,7 @@ IntrinsicInst *buildAssumeFromInst(Instruction *I);
/// Calls BuildAssumeFromInst and if the resulting llvm.assume is valid insert
/// if before I. This is usually what need to be done to salvage the knowledge
/// contained in the instruction I.
+/// If \p AC is non-null, the inserted llvm.assume is also registered in it.
-void salvageKnowledge(Instruction *I);
+void salvageKnowledge(Instruction *I, AssumptionCache *AC = nullptr);
/// This pass will try to build an llvm.assume for every instruction in the
/// function. Its main purpose is testing.
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index f4d4a5ac8f88..205d758ffc14 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -41,7 +42,7 @@ static cl::opt<bool>
cl::desc("Enable verification of assumption cache"),
cl::init(false));
-SmallVector<WeakTrackingVH, 1> &
+SmallVector<AssumptionCache::ResultElem, 1> &
AssumptionCache::getOrInsertAffectedValues(Value *V) {
// Try using find_as first to avoid creating extra value handles just for the
// purpose of doing the lookup.
@@ -50,32 +51,39 @@ AssumptionCache::getOrInsertAffectedValues(Value *V) {
return AVI->second;
auto AVIP = AffectedValues.insert(
- {AffectedValueCallbackVH(V, this), SmallVector<WeakTrackingVH, 1>()});
+ {AffectedValueCallbackVH(V, this), SmallVector<ResultElem, 1>()});
return AVIP.first->second;
}
-static void findAffectedValues(CallInst *CI,
- SmallVectorImpl<Value *> &Affected) {
+static void
+findAffectedValues(CallInst *CI,
+ SmallVectorImpl<AssumptionCache::ResultElem> &Affected) {
// Note: This code must be kept in-sync with the code in
// computeKnownBitsFromAssume in ValueTracking.
- auto AddAffected = [&Affected](Value *V) {
+ auto AddAffected = [&Affected](Value *V, unsigned Idx =
+ AssumptionCache::ExprResultIdx) {
if (isa<Argument>(V)) {
- Affected.push_back(V);
+ Affected.push_back({V, Idx});
} else if (auto *I = dyn_cast<Instruction>(V)) {
- Affected.push_back(I);
+ Affected.push_back({I, Idx});
// Peek through unary operators to find the source of the condition.
Value *Op;
if (match(I, m_BitCast(m_Value(Op))) ||
- match(I, m_PtrToInt(m_Value(Op))) ||
- match(I, m_Not(m_Value(Op)))) {
+ match(I, m_PtrToInt(m_Value(Op))) || match(I, m_Not(m_Value(Op)))) {
if (isa<Instruction>(Op) || isa<Argument>(Op))
- Affected.push_back(Op);
+ Affected.push_back({Op, Idx});
}
}
};
+ for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++) {
+ if (CI->getOperandBundleAt(Idx).Inputs.size() > ABA_WasOn &&
+ CI->getOperandBundleAt(Idx).getTagName() != "ignore")
+ AddAffected(CI->getOperandBundleAt(Idx).Inputs[ABA_WasOn], Idx);
+ }
+
Value *Cond = CI->getArgOperand(0), *A, *B;
AddAffected(Cond);
@@ -112,28 +120,44 @@ static void findAffectedValues(CallInst *CI,
}
void AssumptionCache::updateAffectedValues(CallInst *CI) {
- SmallVector<Value *, 16> Affected;
+ SmallVector<ResultElem, 16> Affected;
findAffectedValues(CI, Affected);
for (auto &AV : Affected) {
- auto &AVV = getOrInsertAffectedValues(AV);
- if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end())
- AVV.push_back(CI);
+ // Record the (CI, bundle index) pair under this value at most once.
+ auto &Entries = getOrInsertAffectedValues(AV.Assume);
+ bool AlreadyCached = llvm::any_of(Entries, [&](const ResultElem &Entry) {
+ return Entry.Assume == CI && Entry.Index == AV.Index;
+ });
+ if (!AlreadyCached)
+ Entries.push_back({CI, AV.Index});
}
}
void AssumptionCache::unregisterAssumption(CallInst *CI) {
- SmallVector<Value *, 16> Affected;
+ SmallVector<AssumptionCache::ResultElem, 16> Affected;
findAffectedValues(CI, Affected);
+ // The same affected value can appear several times in Affected (e.g. as the
+ // WasOn operand of several bundles of CI), and its entry list can hold one
+ // entry per bundle index of CI. Visit each value once and clear *all* of
+ // CI's entries for it in that visit: stopping at the first one would leave
+ // stale entries behind, and a second visit of the same value could find
+ // nothing left to clear and trip the assertion below.
+ // (In Affected, the ResultElem::Assume field holds the affected value.)
+ SmallPtrSet<Value *, 16> Visited;
for (auto &AV : Affected) {
- auto AVI = AffectedValues.find_as(AV);
- if (AVI != AffectedValues.end())
+ if (!Visited.insert(AV.Assume).second)
+ continue;
+ auto AVI = AffectedValues.find_as(AV.Assume);
+ if (AVI == AffectedValues.end())
+ continue;
+ bool Found = false;
+ bool HasNonnull = false;
+ // No early exit: every entry for CI must be nulled, not just the first.
+ for (ResultElem &Elem : AVI->second) {
+ if (Elem.Assume == CI) {
+ Found = true;
+ Elem.Assume = nullptr;
+ }
+ HasNonnull |= !!Elem.Assume;
+ }
+ assert(Found && "already unregistered or incorrect cache state");
+ // Drop the map entry once no live assumption refers to this value.
+ if (!HasNonnull)
AffectedValues.erase(AVI);
}
AssumeHandles.erase(
- remove_if(AssumeHandles, [CI](WeakTrackingVH &VH) { return CI == VH; }),
+ remove_if(AssumeHandles, [CI](ResultElem &RE) { return CI == RE; }),
AssumeHandles.end());
}
@@ -177,7 +201,7 @@ void AssumptionCache::scanFunction() {
for (BasicBlock &B : F)
for (Instruction &II : B)
if (match(&II, m_Intrinsic<Intrinsic::assume>()))
- AssumeHandles.push_back(&II);
+ AssumeHandles.push_back({&II, ExprResultIdx});
// Mark the scan as complete.
Scanned = true;
@@ -196,7 +220,7 @@ void AssumptionCache::registerAssumption(CallInst *CI) {
if (!Scanned)
return;
- AssumeHandles.push_back(CI);
+ AssumeHandles.push_back({CI, ExprResultIdx});
#ifndef NDEBUG
assert(CI->getParent() &&
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index fb9db5657db2..d8636d3849c6 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -948,7 +948,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
continue;
}
- salvageKnowledge(&Inst);
+ salvageKnowledge(&Inst, &AC);
salvageDebugInfoOrMarkUndef(Inst);
removeMSSA(Inst);
Inst.eraseFromParent();
@@ -1015,7 +1015,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
cast<ConstantInt>(KnownCond)->isOne()) {
LLVM_DEBUG(dbgs()
<< "EarlyCSE removing guard: " << Inst << '\n');
- salvageKnowledge(&Inst);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@@ -1051,7 +1051,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
Changed = true;
}
if (isInstructionTriviallyDead(&Inst, &TLI)) {
- salvageKnowledge(&Inst);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@@ -1077,7 +1077,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
if (auto *I = dyn_cast<Instruction>(V))
I->andIRFlags(&Inst);
Inst.replaceAllUsesWith(V);
- salvageKnowledge(&Inst);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@@ -1138,7 +1138,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
}
if (!Inst.use_empty())
Inst.replaceAllUsesWith(Op);
- salvageKnowledge(&Inst);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@@ -1182,7 +1182,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
}
if (!Inst.use_empty())
Inst.replaceAllUsesWith(InVal.first);
- salvageKnowledge(&Inst);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@@ -1235,7 +1235,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
continue;
}
- salvageKnowledge(&Inst);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(Inst);
Inst.eraseFromParent();
Changed = true;
@@ -1271,7 +1271,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
if (!DebugCounter::shouldExecute(CSECounter)) {
LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
} else {
- salvageKnowledge(&Inst);
+ salvageKnowledge(&Inst, &AC);
removeMSSA(*LastStore);
LastStore->eraseFromParent();
Changed = true;
diff --git a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
index cd677f7d0fb7..682e69f1e662 100644
--- a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
+++ b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
@@ -8,6 +8,7 @@
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
@@ -222,9 +223,12 @@ IntrinsicInst *llvm::buildAssumeFromInst(Instruction *I) {
return Builder.build();
}
-void llvm::salvageKnowledge(Instruction *I) {
- if (Instruction *Intr = buildAssumeFromInst(I))
+/// Build an llvm.assume from the knowledge carried by \p I, insert it before
+/// \p I and, when \p AC is non-null, register the new assume with that cache.
+void llvm::salvageKnowledge(Instruction *I, AssumptionCache *AC) {
+ if (IntrinsicInst *Intr = buildAssumeFromInst(I)) {
Intr->insertBefore(I);
+ // Keep the caller's assumption cache (if any) aware of the new assume.
+ if (AC)
+ AC->registerAssumption(Intr);
+ }
}
PreservedAnalyses AssumeBuilderPass::run(Function &F,
diff --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index d7dd342004e2..593b4bc889be 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -1837,9 +1837,11 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
// check what will be known at the start of the inlined code.
AddAlignmentAssumptions(CS, IFI);
+ AssumptionCache *AC =
+ IFI.GetAssumptionCache ? &(*IFI.GetAssumptionCache)(*Caller) : nullptr;
+
/// Preserve all attributes on of the call and its parameters.
- if (Instruction *Assume = buildAssumeFromInst(CS.getInstruction()))
- Assume->insertBefore(CS.getInstruction());
+ salvageKnowledge(CS.getInstruction(), AC);
// We want the inliner to prune the code as it copies. We would LOVE to
// have no dead or constant instructions leftover after inlining occurs
diff --git a/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp b/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
index f6f4849a94e8..62293b997d5d 100644
--- a/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
+++ b/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/AssumeBundleQueries.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/CallSite.h"
@@ -510,3 +511,66 @@ TEST(AssumeQueryAPI, getKnowledgeFromUseInAssume) {
// large.
RunRandTest(9876789, 100000, -0, 7, 100);
}
+
+// Checks that AssumptionCache records, per affected value, which llvm.assume
+// provides the knowledge and where in it (operand bundle index, or
+// ExprResultIdx for the condition), and that unregisterAssumption clears the
+// corresponding entries.
+TEST(AssumeQueryAPI, AssumptionCache) {
+ LLVMContext C;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> Mod = parseAssemblyString(
+ "declare void @llvm.assume(i1)\n"
+ "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3, i1 %B) {\n"
+ "call void @llvm.assume(i1 true) [\"nonnull\"(i32* %P), \"align\"(i32* "
+ "%P2, i32 4), \"align\"(i32* %P, i32 8)]\n"
+ "call void @llvm.assume(i1 %B) [\"test\"(i32* %P1), "
+ "\"dereferenceable\"(i32* %P, i32 4)]\n"
+ "ret void\n}\n",
+ Err, C);
+ if (!Mod)
+ Err.print("AssumeQueryAPI", errs());
+ // Stop here instead of dereferencing a null module below.
+ ASSERT_TRUE(Mod != nullptr);
+ Function *F = Mod->getFunction("test");
+ BasicBlock::iterator First = F->begin()->begin();
+ BasicBlock::iterator Second = std::next(First);
+ AssumptionCacheTracker ACT;
+ AssumptionCache &AC = ACT.getAssumptionCache(*F);
+ auto AR = AC.assumptionsFor(F->getArg(3));
+ ASSERT_EQ(AR.size(), 0u);
+ AR = AC.assumptionsFor(F->getArg(1));
+ ASSERT_EQ(AR.size(), 1u);
+ ASSERT_EQ(AR[0].Index, 0u);
+ ASSERT_EQ(AR[0].Assume, &*Second);
+ AR = AC.assumptionsFor(F->getArg(2));
+ ASSERT_EQ(AR.size(), 1u);
+ ASSERT_EQ(AR[0].Index, 1u);
+ ASSERT_EQ(AR[0].Assume, &*First);
+ AR = AC.assumptionsFor(F->getArg(0));
+ ASSERT_EQ(AR.size(), 3u);
+ // AR aliases the cache's own entry list, so this sorts it in place.
+ llvm::sort(AR,
+ [](const auto &L, const auto &R) { return L.Index < R.Index; });
+ ASSERT_EQ(AR[0].Assume, &*First);
+ ASSERT_EQ(AR[0].Index, 0u);
+ ASSERT_EQ(AR[1].Assume, &*Second);
+ ASSERT_EQ(AR[1].Index, 1u);
+ ASSERT_EQ(AR[2].Assume, &*First);
+ ASSERT_EQ(AR[2].Index, 2u);
+ AR = AC.assumptionsFor(F->getArg(4));
+ ASSERT_EQ(AR.size(), 1u);
+ ASSERT_EQ(AR[0].Assume, &*Second);
+ ASSERT_EQ(AR[0].Index, AssumptionCache::ExprResultIdx);
+ AC.unregisterAssumption(cast<CallInst>(&*Second));
+ AR = AC.assumptionsFor(F->getArg(1));
+ ASSERT_EQ(AR.size(), 0u);
+ AR = AC.assumptionsFor(F->getArg(0));
+ ASSERT_EQ(AR.size(), 3u);
+ llvm::sort(AR,
+ [](const auto &L, const auto &R) { return L.Index < R.Index; });
+ ASSERT_EQ(AR[0].Assume, &*First);
+ ASSERT_EQ(AR[0].Index, 0u);
+ ASSERT_EQ(AR[1].Assume, nullptr);
+ ASSERT_EQ(AR[1].Index, 1u);
+ ASSERT_EQ(AR[2].Assume, &*First);
+ ASSERT_EQ(AR[2].Index, 2u);
+ AR = AC.assumptionsFor(F->getArg(2));
+ ASSERT_EQ(AR.size(), 1u);
+ ASSERT_EQ(AR[0].Index, 1u);
+ ASSERT_EQ(AR[0].Assume, &*First);
+}
More information about the llvm-commits
mailing list