[llvm] 813f438 - [AssumeBundles] adapt Assumption cache to assume bundles

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 13 03:12:30 PDT 2020


Author: Tyker
Date: 2020-04-13T12:04:51+02:00
New Revision: 813f438baaa9638529023b2218875e01ea037735

URL: https://github.com/llvm/llvm-project/commit/813f438baaa9638529023b2218875e01ea037735
DIFF: https://github.com/llvm/llvm-project/commit/813f438baaa9638529023b2218875e01ea037735.diff

LOG: [AssumeBundles] adapt Assumption cache to assume bundles

Summary: Change the assumption cache to store each llvm.assume together with the index of the operand bundle that carries the knowledge.

Reviewers: jdoerfert, hfinkel

Reviewed By: jdoerfert

Subscribers: hiraditya, mgrang, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77402

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/AssumptionCache.h
    llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
    llvm/lib/Analysis/AssumptionCache.cpp
    llvm/lib/Transforms/Scalar/EarlyCSE.cpp
    llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
    llvm/lib/Transforms/Utils/InlineFunction.cpp
    llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/AssumptionCache.h b/llvm/include/llvm/Analysis/AssumptionCache.h
index 0efbd59023d6..0ef63dc68e1c 100644
--- a/llvm/include/llvm/Analysis/AssumptionCache.h
+++ b/llvm/include/llvm/Analysis/AssumptionCache.h
@@ -39,6 +39,21 @@ class Value;
 /// register any new \@llvm.assume calls that they create. Deletions of
 /// \@llvm.assume calls do not require special handling.
 class AssumptionCache {
+public:
+  /// Value of ResultElem::Index meaning the knowledge comes from the i1
+  /// condition argument of the llvm.assume call, not from an operand bundle.
+  enum : unsigned { ExprResultIdx = std::numeric_limits<unsigned>::max() };
+
+  struct ResultElem {
+    WeakTrackingVH Assume; // May be null, e.g. after unregisterAssumption.
+
+    /// Either ExprResultIdx or the index of the operand bundle that carries
+    /// the knowledge.
+    unsigned Index;
+    operator Value *() const { return Assume; } // Usable as a plain Value *.
+  };
+
+private:
   /// The function for which this cache is handling assumptions.
   ///
   /// We track this to lazily populate our assumptions.
@@ -46,7 +61,7 @@ class AssumptionCache {
 
   /// Vector of weak value handles to calls of the \@llvm.assume
   /// intrinsic.
-  SmallVector<WeakTrackingVH, 4> AssumeHandles;
+  SmallVector<ResultElem, 4> AssumeHandles;
 
   class AffectedValueCallbackVH final : public CallbackVH {
     AssumptionCache *AC;
@@ -66,12 +81,12 @@ class AssumptionCache {
   /// A map of values about which an assumption might be providing
   /// information to the relevant set of assumptions.
   using AffectedValuesMap =
-      DenseMap<AffectedValueCallbackVH, SmallVector<WeakTrackingVH, 1>,
+      DenseMap<AffectedValueCallbackVH, SmallVector<ResultElem, 1>,
                AffectedValueCallbackVH::DMI>;
   AffectedValuesMap AffectedValues;
 
   /// Get the vector of assumptions which affect a value from the cache.
-  SmallVector<WeakTrackingVH, 1> &getOrInsertAffectedValues(Value *V);
+  SmallVector<ResultElem, 1> &getOrInsertAffectedValues(Value *V);
 
   /// Move affected values in the cache for OV to be affected values for NV.
   void transferAffectedValuesInCache(Value *OV, Value *NV);
@@ -128,20 +143,20 @@ class AssumptionCache {
   /// FIXME: We should replace this with pointee_iterator<filter_iterator<...>>
   /// when we can write that to filter out the null values. Then caller code
   /// will become simpler.
-  MutableArrayRef<WeakTrackingVH> assumptions() {
+  MutableArrayRef<ResultElem> assumptions() {
     if (!Scanned)
       scanFunction();
     return AssumeHandles;
   }
 
   /// Access the list of assumptions which affect this value.
-  MutableArrayRef<WeakTrackingVH> assumptionsFor(const Value *V) {
+  MutableArrayRef<ResultElem> assumptionsFor(const Value *V) {
     if (!Scanned)
       scanFunction();
 
     auto AVI = AffectedValues.find_as(const_cast<Value *>(V));
     if (AVI == AffectedValues.end())
-      return MutableArrayRef<WeakTrackingVH>();
+      return MutableArrayRef<ResultElem>();
 
     return AVI->second;
   }
@@ -234,6 +249,21 @@ class AssumptionCacheTracker : public ImmutablePass {
   static char ID; // Pass identification, replacement for typeid
 };
 
+template<> struct simplify_type<AssumptionCache::ResultElem> {
+  using SimpleType = Value *;
+
+  static SimpleType getSimplifiedValue(AssumptionCache::ResultElem &Val) {
+    return Val;
+  }
+};
+template<> struct simplify_type<const AssumptionCache::ResultElem> {
+  using SimpleType = /*const*/ Value *;
+
+  static SimpleType getSimplifiedValue(const AssumptionCache::ResultElem &Val) {
+    return Val;
+  }
+};
+
 } // end namespace llvm
 
 #endif // LLVM_ANALYSIS_ASSUMPTIONCACHE_H

diff  --git a/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h b/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
index 50f89194411a..f1cd0239b008 100644
--- a/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
+++ b/llvm/include/llvm/Transforms/Utils/AssumeBundleBuilder.h
@@ -22,6 +22,7 @@
 
 namespace llvm {
 class IntrinsicInst;
+class AssumptionCache;
 
 /// Build a call to llvm.assume to preserve informations that can be derived
 /// from the given instruction.
@@ -32,7 +33,7 @@ IntrinsicInst *buildAssumeFromInst(Instruction *I);
 /// Calls BuildAssumeFromInst and if the resulting llvm.assume is valid insert
 /// if before I. This is usually what need to be done to salvage the knowledge
 /// contained in the instruction I.
-void salvageKnowledge(Instruction *I);
+void salvageKnowledge(Instruction *I, AssumptionCache *AC = nullptr);
 
 /// This pass will try to build an llvm.assume for every instruction in the
 /// function. Its main purpose is testing.

diff  --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index f4d4a5ac8f88..205d758ffc14 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Analysis/AssumeBundleQueries.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -41,7 +42,7 @@ static cl::opt<bool>
                           cl::desc("Enable verification of assumption cache"),
                           cl::init(false));
 
-SmallVector<WeakTrackingVH, 1> &
+SmallVector<AssumptionCache::ResultElem, 1> &
 AssumptionCache::getOrInsertAffectedValues(Value *V) {
   // Try using find_as first to avoid creating extra value handles just for the
   // purpose of doing the lookup.
@@ -50,32 +51,39 @@ AssumptionCache::getOrInsertAffectedValues(Value *V) {
     return AVI->second;
 
   auto AVIP = AffectedValues.insert(
-      {AffectedValueCallbackVH(V, this), SmallVector<WeakTrackingVH, 1>()});
+      {AffectedValueCallbackVH(V, this), SmallVector<ResultElem, 1>()});
   return AVIP.first->second;
 }
 
-static void findAffectedValues(CallInst *CI,
-                               SmallVectorImpl<Value *> &Affected) {
+static void
+findAffectedValues(CallInst *CI,
+                   SmallVectorImpl<AssumptionCache::ResultElem> &Affected) {
   // Note: This code must be kept in-sync with the code in
   // computeKnownBitsFromAssume in ValueTracking.
 
-  auto AddAffected = [&Affected](Value *V) {
+  auto AddAffected = [&Affected](Value *V, unsigned Idx =
+                                               AssumptionCache::ExprResultIdx) {
     if (isa<Argument>(V)) {
-      Affected.push_back(V);
+      Affected.push_back({V, Idx});
     } else if (auto *I = dyn_cast<Instruction>(V)) {
-      Affected.push_back(I);
+      Affected.push_back({I, Idx});
 
       // Peek through unary operators to find the source of the condition.
       Value *Op;
       if (match(I, m_BitCast(m_Value(Op))) ||
-          match(I, m_PtrToInt(m_Value(Op))) ||
-          match(I, m_Not(m_Value(Op)))) {
+          match(I, m_PtrToInt(m_Value(Op))) || match(I, m_Not(m_Value(Op)))) {
         if (isa<Instruction>(Op) || isa<Argument>(Op))
-          Affected.push_back(Op);
+          Affected.push_back({Op, Idx});
       }
     }
   };
 
+  for (unsigned Idx = 0; Idx != CI->getNumOperandBundles(); Idx++) {
+    if (CI->getOperandBundleAt(Idx).Inputs.size() > ABA_WasOn &&
+        CI->getOperandBundleAt(Idx).getTagName() != "ignore")
+      AddAffected(CI->getOperandBundleAt(Idx).Inputs[ABA_WasOn], Idx);
+  }
+
   Value *Cond = CI->getArgOperand(0), *A, *B;
   AddAffected(Cond);
 
@@ -112,28 +120,44 @@ static void findAffectedValues(CallInst *CI,
 }
 
 void AssumptionCache::updateAffectedValues(CallInst *CI) {
-  SmallVector<Value *, 16> Affected;
+  SmallVector<AssumptionCache::ResultElem, 16> Affected;
   findAffectedValues(CI, Affected);
 
   for (auto &AV : Affected) {
-    auto &AVV = getOrInsertAffectedValues(AV);
-    if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end())
-      AVV.push_back(CI);
+    auto &AVV = getOrInsertAffectedValues(AV.Assume);
+    if (std::find_if(AVV.begin(), AVV.end(), [&](ResultElem &Elem) {
+          return Elem.Assume == CI && Elem.Index == AV.Index;
+        }) == AVV.end())
+      AVV.push_back({CI, AV.Index});
   }
 }
 
 void AssumptionCache::unregisterAssumption(CallInst *CI) {
-  SmallVector<Value *, 16> Affected;
+  SmallVector<AssumptionCache::ResultElem, 16> Affected;
   findAffectedValues(CI, Affected);

   for (auto &AV : Affected) {
-    auto AVI = AffectedValues.find_as(AV);
-    if (AVI != AffectedValues.end())
+    auto AVI = AffectedValues.find_as(AV.Assume);
+    if (AVI == AffectedValues.end())
+      continue;
+    // Clear every entry of this value that refers to CI. The value may be
+    // listed in Affected several times (once per bundle index, or more than
+    // once for the same index), so later visits may legitimately find
+    // nothing left to clear; that is why no "found" assertion is made here.
+    bool HasNonnull = false;
+    for (ResultElem &Elem : AVI->second) {
+      if (Elem.Assume == CI)
+        Elem.Assume = nullptr;
+      HasNonnull |= !!Elem.Assume;
+    }
+    // Drop the map entry once no live assumption refers to this value;
+    // otherwise the cleared slots remain behind as null entries.
+    if (!HasNonnull)
       AffectedValues.erase(AVI);
   }

   AssumeHandles.erase(
-      remove_if(AssumeHandles, [CI](WeakTrackingVH &VH) { return CI == VH; }),
+      remove_if(AssumeHandles, [CI](ResultElem &RE) { return CI == RE; }),
       AssumeHandles.end());
 }
 
@@ -177,7 +201,7 @@ void AssumptionCache::scanFunction() {
   for (BasicBlock &B : F)
     for (Instruction &II : B)
       if (match(&II, m_Intrinsic<Intrinsic::assume>()))
-        AssumeHandles.push_back(&II);
+        AssumeHandles.push_back({&II, ExprResultIdx});
 
   // Mark the scan as complete.
   Scanned = true;
@@ -196,7 +220,7 @@ void AssumptionCache::registerAssumption(CallInst *CI) {
   if (!Scanned)
     return;
 
-  AssumeHandles.push_back(CI);
+  AssumeHandles.push_back({CI, ExprResultIdx});
 
 #ifndef NDEBUG
   assert(CI->getParent() &&

diff  --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index fb9db5657db2..d8636d3849c6 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -948,7 +948,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
         continue;
       }
 
-      salvageKnowledge(&Inst);
+      salvageKnowledge(&Inst, &AC);
       salvageDebugInfoOrMarkUndef(Inst);
       removeMSSA(Inst);
       Inst.eraseFromParent();
@@ -1015,7 +1015,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
                 cast<ConstantInt>(KnownCond)->isOne()) {
               LLVM_DEBUG(dbgs()
                          << "EarlyCSE removing guard: " << Inst << '\n');
-              salvageKnowledge(&Inst);
+              salvageKnowledge(&Inst, &AC);
               removeMSSA(Inst);
               Inst.eraseFromParent();
               Changed = true;
@@ -1051,7 +1051,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
           Changed = true;
         }
         if (isInstructionTriviallyDead(&Inst, &TLI)) {
-          salvageKnowledge(&Inst);
+          salvageKnowledge(&Inst, &AC);
           removeMSSA(Inst);
           Inst.eraseFromParent();
           Changed = true;
@@ -1077,7 +1077,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
         if (auto *I = dyn_cast<Instruction>(V))
           I->andIRFlags(&Inst);
         Inst.replaceAllUsesWith(V);
-        salvageKnowledge(&Inst);
+        salvageKnowledge(&Inst, &AC);
         removeMSSA(Inst);
         Inst.eraseFromParent();
         Changed = true;
@@ -1138,7 +1138,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
           }
           if (!Inst.use_empty())
             Inst.replaceAllUsesWith(Op);
-          salvageKnowledge(&Inst);
+          salvageKnowledge(&Inst, &AC);
           removeMSSA(Inst);
           Inst.eraseFromParent();
           Changed = true;
@@ -1182,7 +1182,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
         }
         if (!Inst.use_empty())
           Inst.replaceAllUsesWith(InVal.first);
-        salvageKnowledge(&Inst);
+        salvageKnowledge(&Inst, &AC);
         removeMSSA(Inst);
         Inst.eraseFromParent();
         Changed = true;
@@ -1235,7 +1235,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
           LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
           continue;
         }
-        salvageKnowledge(&Inst);
+        salvageKnowledge(&Inst, &AC);
         removeMSSA(Inst);
         Inst.eraseFromParent();
         Changed = true;
@@ -1271,7 +1271,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
             if (!DebugCounter::shouldExecute(CSECounter)) {
               LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
             } else {
-              salvageKnowledge(&Inst);
+              salvageKnowledge(LastStore, &AC); // Salvage the erased store.
               removeMSSA(*LastStore);
               LastStore->eraseFromParent();
               Changed = true;

diff  --git a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
index cd677f7d0fb7..682e69f1e662 100644
--- a/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
+++ b/llvm/lib/Transforms/Utils/AssumeBundleBuilder.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
 #include "llvm/Analysis/AssumeBundleQueries.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstIterator.h"
@@ -222,9 +223,12 @@ IntrinsicInst *llvm::buildAssumeFromInst(Instruction *I) {
   return Builder.build();
 }
 
-void llvm::salvageKnowledge(Instruction *I) {
-  if (Instruction *Intr = buildAssumeFromInst(I))
+void llvm::salvageKnowledge(Instruction *I, AssumptionCache *AC) {
+  if (IntrinsicInst *Intr = buildAssumeFromInst(I)) {
     Intr->insertBefore(I);
+    if (AC)
+      AC->registerAssumption(Intr);
+  }
 }
 
 PreservedAnalyses AssumeBuilderPass::run(Function &F,

diff  --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index d7dd342004e2..593b4bc889be 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -1837,9 +1837,11 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
     // check what will be known at the start of the inlined code.
     AddAlignmentAssumptions(CS, IFI);
 
+    AssumptionCache *AC =
+        IFI.GetAssumptionCache ? &(*IFI.GetAssumptionCache)(*Caller) : nullptr;
+
     /// Preserve all attributes on of the call and its parameters.
-    if (Instruction *Assume = buildAssumeFromInst(CS.getInstruction()))
-      Assume->insertBefore(CS.getInstruction());
+    salvageKnowledge(CS.getInstruction(), AC);
 
     // We want the inliner to prune the code as it copies.  We would LOVE to
     // have no dead or constant instructions leftover after inlining occurs

diff  --git a/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp b/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
index f6f4849a94e8..62293b997d5d 100644
--- a/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
+++ b/llvm/unittests/Analysis/AssumeBundleQueriesTest.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/AssumeBundleQueries.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/CallSite.h"
@@ -510,3 +511,66 @@ TEST(AssumeQueryAPI, getKnowledgeFromUseInAssume) {
   // large.
   RunRandTest(9876789, 100000, -0, 7, 100);
 }
+
+TEST(AssumeQueryAPI, AssumptionCache) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(
+      "declare void @llvm.assume(i1)\n"
+      "define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3, i1 %B) {\n"
+      "call void @llvm.assume(i1 true) [\"nonnull\"(i32* %P), \"align\"(i32* "
+      "%P2, i32 4), \"align\"(i32* %P, i32 8)]\n"
+      "call void @llvm.assume(i1 %B) [\"test\"(i32* %P1), "
+      "\"dereferenceable\"(i32* %P, i32 4)]\n"
+      "ret void\n}\n",
+      Err, C);
+  if (!Mod)
+    Err.print("AssumeQueryAPI", errs());
+  ASSERT_TRUE(Mod) << "failed to parse the test IR";
+  Function *F = Mod->getFunction("test");
+  BasicBlock::iterator First = F->begin()->begin();
+  BasicBlock::iterator Second = std::next(First);
+  AssumptionCacheTracker ACT;
+  AssumptionCache &AC = ACT.getAssumptionCache(*F);
+  auto AR = AC.assumptionsFor(F->getArg(3)); // %P3 is in no assume.
+  ASSERT_EQ(AR.size(), 0u);
+  AR = AC.assumptionsFor(F->getArg(1));
+  ASSERT_EQ(AR.size(), 1u);
+  ASSERT_EQ(AR[0].Index, 0u);
+  ASSERT_EQ(AR[0].Assume, &*Second);
+  AR = AC.assumptionsFor(F->getArg(2));
+  ASSERT_EQ(AR.size(), 1u);
+  ASSERT_EQ(AR[0].Index, 1u);
+  ASSERT_EQ(AR[0].Assume, &*First);
+  AR = AC.assumptionsFor(F->getArg(0));
+  ASSERT_EQ(AR.size(), 3u);
+  llvm::sort(AR,
+             [](const auto &L, const auto &R) { return L.Index < R.Index; });
+  ASSERT_EQ(AR[0].Assume, &*First);
+  ASSERT_EQ(AR[0].Index, 0u);
+  ASSERT_EQ(AR[1].Assume, &*Second);
+  ASSERT_EQ(AR[1].Index, 1u);
+  ASSERT_EQ(AR[2].Assume, &*First);
+  ASSERT_EQ(AR[2].Index, 2u);
+  AR = AC.assumptionsFor(F->getArg(4));
+  ASSERT_EQ(AR.size(), 1u);
+  ASSERT_EQ(AR[0].Assume, &*Second);
+  ASSERT_EQ(AR[0].Index, AssumptionCache::ExprResultIdx);
+  AC.unregisterAssumption(cast<CallInst>(&*Second));
+  AR = AC.assumptionsFor(F->getArg(1));
+  ASSERT_EQ(AR.size(), 0u);
+  AR = AC.assumptionsFor(F->getArg(0));
+  ASSERT_EQ(AR.size(), 3u);
+  llvm::sort(AR,
+             [](const auto &L, const auto &R) { return L.Index < R.Index; });
+  ASSERT_EQ(AR[0].Assume, &*First);
+  ASSERT_EQ(AR[0].Index, 0u);
+  ASSERT_EQ(AR[1].Assume, nullptr); // Cleared, not removed.
+  ASSERT_EQ(AR[1].Index, 1u);
+  ASSERT_EQ(AR[2].Assume, &*First);
+  ASSERT_EQ(AR[2].Index, 2u);
+  AR = AC.assumptionsFor(F->getArg(2));
+  ASSERT_EQ(AR.size(), 1u);
+  ASSERT_EQ(AR[0].Index, 1u);
+  ASSERT_EQ(AR[0].Assume, &*First);
+}


        


More information about the llvm-commits mailing list