[llvm] b2d1ae0 - [Attributor] AAFunctionReachability, Instruction reachability.

Johannes Doerfert via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 31 23:43:03 PST 2022


Author: Kuter Dinel
Date: 2022-02-01T01:40:44-06:00
New Revision: b2d1ae061153b2364f15f58ee00e46b49bad3544

URL: https://github.com/llvm/llvm-project/commit/b2d1ae061153b2364f15f58ee00e46b49bad3544
DIFF: https://github.com/llvm/llvm-project/commit/b2d1ae061153b2364f15f58ee00e46b49bad3544.diff

LOG: [Attributor] AAFunctionReachability, Instruction reachability.

This patch implements instruction reachability for the AAFunctionReachability
attribute. It is used to tell whether a certain instruction can transitively
reach a function.

NOTE: I created a new commit based on D106720 and set the author back to
      Kuter; the remaining metadata is not accurate. I also addressed the
      remaining review comments and fixed the unit test.

Differential Revision: https://reviews.llvm.org/D106720

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/IPO/Attributor.h
    llvm/lib/Transforms/IPO/AttributorAttributes.cpp
    llvm/unittests/Transforms/IPO/AttributorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index b31dceb1b7008..b225a08880d90 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -4616,17 +4616,25 @@ struct AAFunctionReachability
   AAFunctionReachability(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
 
   /// If the function represented by this possition can reach \p Fn.
-  virtual bool canReach(Attributor &A, Function *Fn) const = 0;
+  virtual bool canReach(Attributor &A, const Function &Fn) const = 0;
 
   /// Can \p CB reach \p Fn
-  virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
+  virtual bool canReach(Attributor &A, CallBase &CB,
+                        const Function &Fn) const = 0;
+
+  /// Can \p Inst reach \p Fn
+  virtual bool instructionCanReach(Attributor &A, const Instruction &Inst,
+                                   const Function &Fn,
+                                   bool UseBackwards = true) const = 0;
 
   /// Create an abstract attribute view for the position \p IRP.
   static AAFunctionReachability &createForPosition(const IRPosition &IRP,
                                                    Attributor &A);
 
   /// See AbstractAttribute::getName()
-  const std::string getName() const override { return "AAFunctionReachability"; }
+  const std::string getName() const override {
+    return "AAFunctionReachability";
+  }
 
   /// See AbstractAttribute::getIdAddr()
   const char *getIdAddr() const override { return &ID; }

diff  --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 5aad4ae4742c4..92ce8c141ca6b 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -656,7 +656,7 @@ struct AACallSiteReturnedFromReturned : public BaseType {
     if (!AssociatedFunction)
       return S.indicatePessimisticFixpoint();
 
-    CallBase &CBContext = static_cast<CallBase &>(this->getAnchorValue());
+    CallBase &CBContext = cast<CallBase>(this->getAnchorValue());
     if (IntroduceCallBaseContext)
       LLVM_DEBUG(dbgs() << "[Attributor] Introducing call base context:"
                         << CBContext << "\n");
@@ -2468,7 +2468,7 @@ struct AANoRecurseFunction final : AANoRecurseImpl {
     const AAFunctionReachability &EdgeReachability =
         A.getAAFor<AAFunctionReachability>(*this, getIRPosition(),
                                            DepClassTy::REQUIRED);
-    if (EdgeReachability.canReach(A, getAnchorScope()))
+    if (EdgeReachability.canReach(A, *getAnchorScope()))
       return indicatePessimisticFixpoint();
     return ChangeStatus::UNCHANGED;
   }
@@ -9482,7 +9482,7 @@ struct AACallEdgesCallSite : public AACallEdgesImpl {
       }
     };
 
-    CallBase *CB = static_cast<CallBase *>(getCtxI());
+    CallBase *CB = cast<CallBase>(getCtxI());
 
     if (CB->isInlineAsm()) {
       setHasUnknownCallee(false, Change);
@@ -9521,7 +9521,7 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
     ChangeStatus Change = ChangeStatus::UNCHANGED;
 
     auto ProcessCallInst = [&](Instruction &Inst) {
-      CallBase &CB = static_cast<CallBase &>(Inst);
+      CallBase &CB = cast<CallBase>(Inst);
 
       auto &CBEdges = A.getAAFor<AACallEdges>(
           *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
@@ -9552,11 +9552,39 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
 struct AAFunctionReachabilityFunction : public AAFunctionReachability {
 private:
   struct QuerySet {
-    void markReachable(Function *Fn) {
-      Reachable.insert(Fn);
-      Unreachable.erase(Fn);
+    void markReachable(const Function &Fn) {
+      Reachable.insert(&Fn);
+      Unreachable.erase(&Fn);
     }
 
+    /// If there is no information about the function None is returned.
+    Optional<bool> isCachedReachable(const Function &Fn) {
+      // Assume that we can reach the function.
+      // TODO: Be more specific with the unknown callee.
+      if (CanReachUnknownCallee)
+        return true;
+
+      if (Reachable.count(&Fn))
+        return true;
+
+      if (Unreachable.count(&Fn))
+        return false;
+
+      return llvm::None;
+    }
+
+    /// Set of functions that we know for sure are reachable.
+    DenseSet<const Function *> Reachable;
+
+    /// Set of functions that are unreachable, but might become reachable.
+    DenseSet<const Function *> Unreachable;
+
+    /// If we can reach a function with a call to an unknown function we assume
+    /// that we can reach any function.
+    bool CanReachUnknownCallee = false;
+  };
+
+  struct QueryResolver : public QuerySet {
     ChangeStatus update(Attributor &A, const AAFunctionReachability &AA,
                         ArrayRef<const AACallEdges *> AAEdgesList) {
       ChangeStatus Change = ChangeStatus::UNCHANGED;
@@ -9570,31 +9598,25 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
         }
       }
 
-      for (Function *Fn : make_early_inc_range(Unreachable)) {
-        if (checkIfReachable(A, AA, AAEdgesList, Fn)) {
+      for (const Function *Fn : make_early_inc_range(Unreachable)) {
+        if (checkIfReachable(A, AA, AAEdgesList, *Fn)) {
           Change = ChangeStatus::CHANGED;
-          markReachable(Fn);
+          markReachable(*Fn);
         }
       }
       return Change;
     }
 
     bool isReachable(Attributor &A, const AAFunctionReachability &AA,
-                     ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
-      // Assume that we can reach the function.
-      // TODO: Be more specific with the unknown callee.
-      if (CanReachUnknownCallee)
-        return true;
-
-      if (Reachable.count(Fn))
-        return true;
-
-      if (Unreachable.count(Fn))
-        return false;
+                     ArrayRef<const AACallEdges *> AAEdgesList,
+                     const Function &Fn) {
+      Optional<bool> Cached = isCachedReachable(Fn);
+      if (Cached.hasValue())
+        return Cached.getValue();
 
       // We need to assume that this function can't reach Fn to prevent
       // an infinite loop if this function is recursive.
-      Unreachable.insert(Fn);
+      Unreachable.insert(&Fn);
 
       bool Result = checkIfReachable(A, AA, AAEdgesList, Fn);
       if (Result)
@@ -9604,13 +9626,13 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
 
     bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA,
                           ArrayRef<const AACallEdges *> AAEdgesList,
-                          Function *Fn) const {
+                          const Function &Fn) const {
 
       // Handle the most trivial case first.
       for (auto *AAEdges : AAEdgesList) {
         const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();
 
-        if (Edges.count(Fn))
+        if (Edges.count(const_cast<Function *>(&Fn)))
           return true;
       }
 
@@ -9631,28 +9653,80 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
       }
 
       // The result is false for now, set dependencies and leave.
-      for (auto Dep : Deps)
-        A.recordDependence(AA, *Dep, DepClassTy::REQUIRED);
+      for (auto *Dep : Deps)
+        A.recordDependence(*Dep, AA, DepClassTy::REQUIRED);
 
       return false;
     }
+  };
 
-    /// Set of functions that we know for sure is reachable.
-    DenseSet<Function *> Reachable;
+  /// Get call edges that can be reached by this instruction.
+  bool getReachableCallEdges(Attributor &A, const AAReachability &Reachability,
+                             const Instruction &Inst,
+                             SmallVector<const AACallEdges *> &Result) const {
+    // Determine call like instructions that we can reach from the inst.
+    auto CheckCallBase = [&](Instruction &CBInst) {
+      if (!Reachability.isAssumedReachable(A, Inst, CBInst))
+        return true;
 
-    /// Set of functions that are unreachable, but might become reachable.
-    DenseSet<Function *> Unreachable;
+      const auto &CB = cast<CallBase>(CBInst);
+      const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
+          *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
 
-    /// If we can reach a function with a call to a unknown function we assume
-    /// that we can reach any function.
-    bool CanReachUnknownCallee = false;
-  };
+      Result.push_back(&AAEdges);
+      return true;
+    };
+
+    bool UsedAssumedInformation = false;
+    return A.checkForAllCallLikeInstructions(CheckCallBase, *this,
+                                             UsedAssumedInformation);
+  }
+
+  ChangeStatus checkReachableBackwards(Attributor &A, QuerySet &Set) {
+    ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+    // For all remaining instruction queries, check
+    // callers. A call inside that function might satisfy the query.
+    auto CheckCallSite = [&](AbstractCallSite CallSite) {
+      CallBase *CB = CallSite.getInstruction();
+      if (!CB)
+        return false;
+
+      if (isa<InvokeInst>(CB))
+        return false;
+
+      Instruction *Inst = CB->getNextNonDebugInstruction();
+      const AAFunctionReachability &AA = A.getAAFor<AAFunctionReachability>(
+          *this, IRPosition::function(*Inst->getFunction()),
+          DepClassTy::REQUIRED);
+      for (const Function *Fn : make_early_inc_range(Set.Unreachable)) {
+        if (AA.instructionCanReach(A, *Inst, *Fn, /* UseBackwards */ false)) {
+          Set.markReachable(*Fn);
+          Change = ChangeStatus::CHANGED;
+        }
+      }
+      return true;
+    };
+
+    bool NoUnknownCall = true;
+    if (A.checkForAllCallSites(CheckCallSite, *this, true, NoUnknownCall))
+      return Change;
+
+    // If we don't know all callsites we have to assume that we can reach fn.
+    for (auto &QSet : InstQueriesBackwards) {
+      if (!QSet.second.CanReachUnknownCallee)
+        Change = ChangeStatus::CHANGED;
+      QSet.second.CanReachUnknownCallee = true;
+    }
+
+    return Change;
+  }
 
 public:
   AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
       : AAFunctionReachability(IRP, A) {}
 
-  bool canReach(Attributor &A, Function *Fn) const override {
+  bool canReach(Attributor &A, const Function &Fn) const override {
     const AACallEdges &AAEdges =
         A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
 
@@ -9668,7 +9742,8 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
   }
 
   /// Can \p CB reach \p Fn
-  bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override {
+  bool canReach(Attributor &A, CallBase &CB,
+                const Function &Fn) const override {
     const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
         *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
 
@@ -9677,13 +9752,52 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
     // a const_cast.
     // This is a hack for us to be able to cache queries.
     auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
-    QuerySet &CBQuery = NonConstThis->CBQueries[&CB];
+    QueryResolver &CBQuery = NonConstThis->CBQueries[&CB];
 
     bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn);
 
     return Result;
   }
 
+  bool instructionCanReach(Attributor &A, const Instruction &Inst,
+                           const Function &Fn,
+                           bool UseBackwards) const override {
+    const auto &Reachability = &A.getAAFor<AAReachability>(
+        *this, IRPosition::function(*getAssociatedFunction()),
+        DepClassTy::REQUIRED);
+
+    SmallVector<const AACallEdges *> CallEdges;
+    bool AllKnown = getReachableCallEdges(A, *Reachability, Inst, CallEdges);
+    // Attributor returns attributes as const, so this function has to be
+    // const for users of this attribute to use it without having to do
+    // a const_cast.
+    // This is a hack for us to be able to cache queries.
+    auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+    QueryResolver &InstQSet = NonConstThis->InstQueries[&Inst];
+    if (!AllKnown)
+      InstQSet.CanReachUnknownCallee = true;
+
+    bool ForwardsResult = InstQSet.isReachable(A, *this, CallEdges, Fn);
+    if (ForwardsResult)
+      return true;
+    // We are done.
+    if (!UseBackwards)
+      return false;
+
+    QuerySet &InstBackwardsQSet = NonConstThis->InstQueriesBackwards[&Inst];
+
+    Optional<bool> BackwardsCached = InstBackwardsQSet.isCachedReachable(Fn);
+    if (BackwardsCached.hasValue())
+      return BackwardsCached.getValue();
+
+    // Assume unreachable, to prevent problems.
+    InstBackwardsQSet.Unreachable.insert(&Fn);
+
+    // Check backwards reachability.
+    NonConstThis->checkReachableBackwards(A, InstBackwardsQSet);
+    return InstBackwardsQSet.isCachedReachable(Fn).getValue();
+  }
+
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
     const AACallEdges &AAEdges =
@@ -9692,7 +9806,7 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
 
     Change |= WholeFunction.update(A, *this, {&AAEdges});
 
-    for (auto CBPair : CBQueries) {
+    for (auto &CBPair : CBQueries) {
       const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
           *this, IRPosition::callsite_function(*CBPair.first),
           DepClassTy::REQUIRED);
@@ -9700,6 +9814,29 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
       Change |= CBPair.second.update(A, *this, {&AAEdges});
     }
 
+    // Update the Instruction queries.
+    const AAReachability *Reachability;
+    if (!InstQueries.empty()) {
+      Reachability = &A.getAAFor<AAReachability>(
+          *this, IRPosition::function(*getAssociatedFunction()),
+          DepClassTy::REQUIRED);
+    }
+
+    // Check for local callbases first.
+    for (auto &InstPair : InstQueries) {
+      SmallVector<const AACallEdges *> CallEdges;
+      bool AllKnown =
+          getReachableCallEdges(A, *Reachability, *InstPair.first, CallEdges);
+      // update() reports CHANGED if any queried function becomes reachable.
+      if (!AllKnown)
+        InstPair.second.CanReachUnknownCallee = true;
+      Change |= InstPair.second.update(A, *this, CallEdges);
+    }
+
+    // Update backwards queries.
+    for (auto &QueryPair : InstQueriesBackwards)
+      Change |= checkReachableBackwards(A, QueryPair.second);
+
     return Change;
   }
 
@@ -9720,11 +9857,17 @@ struct AAFunctionReachabilityFunction : public AAFunctionReachability {
   }
 
   /// Used to answer if a the whole function can reacha a specific function.
-  QuerySet WholeFunction;
+  QueryResolver WholeFunction;
 
   /// Used to answer if a call base inside this function can reach a specific
   /// function.
-  DenseMap<CallBase *, QuerySet> CBQueries;
+  DenseMap<const CallBase *, QueryResolver> CBQueries;
+
+  /// This is for instruction queries that scan "forward".
+  DenseMap<const Instruction *, QueryResolver> InstQueries;
+
+  /// This is for instruction queries that scan "backward".
+  DenseMap<const Instruction *, QuerySet> InstQueriesBackwards;
 };
 
 /// ---------------------- Assumption Propagation ------------------------------

diff  --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
index e1d2709e883e1..ea4f1ba70f4c6 100644
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -75,47 +75,69 @@ TEST_F(AttributorTestBase, TestCast) {
 
 TEST_F(AttributorTestBase, AAReachabilityTest) {
   const char *ModuleString = R"(
-    @x = global i32 0
-    define void @func4() {
+    @x = external global i32
+    define internal void @func4() {
       store i32 0, i32* @x
       ret void
     }
 
-    define void @func3() {
+    define internal void @func3() {
       store i32 0, i32* @x
       ret void
     }
 
-    define void @func2() {
+    define internal void @func8() {
+      store i32 0, i32* @x
+      ret void
+    }
+
+    define internal void @func2() {
     entry:
       call void @func3()
       ret void
     }
 
-    define void @func1() {
+    define internal void @func1() {
     entry:
       call void @func2()
       ret void
     }
 
-    define void @func5(void ()* %unknown) {
+    declare void @unknown()
+    define internal void @func5(void ()* %ptr) {
     entry:
-      call void %unknown()
+      call void %ptr()
+      call void @unknown()
       ret void
     }
 
-    define void @func6() {
+    define internal void @func6() {
     entry:
       call void @func5(void ()* @func3)
       ret void
     }
 
-    define void @func7() {
+    define internal void @func7() {
+    entry:
+      call void @func2()
+      call void @func4()
+      ret void
+    }
+
+    define internal void @func9() {
     entry:
       call void @func2()
+      call void @func8()
+      ret void
+    }
+
+    define void @func10() {
+    entry:
+      call void @func9()
       call void @func4()
       ret void
     }
+
   )";
 
   Module &M = parseModule(ModuleString);
@@ -128,32 +150,43 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
   CallGraphUpdater CGUpdater;
   BumpPtrAllocator Allocator;
   InformationCache InfoCache(M, AG, Allocator, nullptr);
-  Attributor A(Functions, InfoCache, CGUpdater);
+  Attributor A(Functions, InfoCache, CGUpdater, /* Allowed */ nullptr,
+               /*DeleteFns*/ false);
 
-  Function *F1 = M.getFunction("func1");
-  Function *F3 = M.getFunction("func3");
-  Function *F4 = M.getFunction("func4");
-  Function *F6 = M.getFunction("func6");
-  Function *F7 = M.getFunction("func7");
+  Function &F1 = *M.getFunction("func1");
+  Function &F3 = *M.getFunction("func3");
+  Function &F4 = *M.getFunction("func4");
+  Function &F6 = *M.getFunction("func6");
+  Function &F7 = *M.getFunction("func7");
+  Function &F9 = *M.getFunction("func9");
 
   // call void @func2()
-  CallBase &F7FirstCB =
-      *static_cast<CallBase *>(F7->getEntryBlock().getFirstNonPHI());
+  CallBase &F7FirstCB = static_cast<CallBase &>(*F7.getEntryBlock().begin());
+  // call void @func2()
+  Instruction &F9FirstInst = *F9.getEntryBlock().begin();
+  // call void @func8
+  Instruction &F9SecondInst = *++(F9.getEntryBlock().begin());
 
   const AAFunctionReachability &F1AA =
-      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F1));
 
   const AAFunctionReachability &F6AA =
-      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F6));
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F6));
 
   const AAFunctionReachability &F7AA =
-      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F7));
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F7));
+
+  const AAFunctionReachability &F9AA =
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(F9));
 
   F1AA.canReach(A, F3);
   F1AA.canReach(A, F4);
   F6AA.canReach(A, F4);
   F7AA.canReach(A, F7FirstCB, F3);
   F7AA.canReach(A, F7FirstCB, F4);
+  F9AA.instructionCanReach(A, F9FirstInst, F3);
+  F9AA.instructionCanReach(A, F9SecondInst, F3, false);
+  F9AA.instructionCanReach(A, F9FirstInst, F4);
 
   A.run();
 
@@ -166,6 +199,15 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
   // Assumed to be reacahable, since F6 can reach a function with
   // a unknown callee.
   ASSERT_TRUE(F6AA.canReach(A, F4));
+
+  // The second instruction of F9 can't reach the first call.
+  ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, false));
+  ASSERT_FALSE(F9AA.instructionCanReach(A, F9SecondInst, F3, true));
+
+  // The first instruction of F9 can reach the first call.
+  ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F3));
+  // Because func10 calls func4 after its call to func9, func4 is reachable.
+  ASSERT_TRUE(F9AA.instructionCanReach(A, F9FirstInst, F4));
 }
 
 } // namespace llvm


        


More information about the llvm-commits mailing list