[llvm] 66a0b34 - [Attributor] AAFunctionReachability, Handle CallBase Reachability.

Kuter Dinel via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 12 15:35:59 PDT 2021


Author: Kuter Dinel
Date: 2021-09-13T01:35:44+03:00
New Revision: 66a0b3464ca8502a6a3b59800f9b31fcd7aa6e97

URL: https://github.com/llvm/llvm-project/commit/66a0b3464ca8502a6a3b59800f9b31fcd7aa6e97
DIFF: https://github.com/llvm/llvm-project/commit/66a0b3464ca8502a6a3b59800f9b31fcd7aa6e97.diff

LOG: [Attributor] AAFunctionReachability, Handle CallBase Reachability.

This patch makes it possible to query call base reachability
(can a call base reach a function Fn transitively?).
The patch moves the reachability query handling logic into a member class;
this class will have more users within the AA once we add other function
reachability queries.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D106402

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/IPO/Attributor.h
    llvm/lib/Transforms/IPO/AttributorAttributes.cpp
    llvm/unittests/Transforms/IPO/AttributorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h
index 97458a988e590..d0ab2f6e3755c 100644
--- a/llvm/include/llvm/Transforms/IPO/Attributor.h
+++ b/llvm/include/llvm/Transforms/IPO/Attributor.h
@@ -4454,6 +4454,9 @@ struct AAFunctionReachability
   /// If the function represented by this possition can reach \p Fn.
   virtual bool canReach(Attributor &A, Function *Fn) const = 0;
 
+  /// Can the call base \p CB (transitively) reach \p Fn?
+  virtual bool canReach(Attributor &A, CallBase &CB, Function *Fn) const = 0;
+
   /// Create an abstract attribute view for the position \p IRP.
   static AAFunctionReachability &createForPosition(const IRPosition &IRP,
                                                    Attributor &A);

diff  --git a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
index 02f1d38f18984..0f73d1925f421 100644
--- a/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
+++ b/llvm/lib/Transforms/IPO/AttributorAttributes.cpp
@@ -9496,115 +9496,189 @@ struct AACallEdgesFunction : public AACallEdgesImpl {
 };
 
 struct AAFunctionReachabilityFunction : public AAFunctionReachability {
-  AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
-      : AAFunctionReachability(IRP, A) {}
+private:
+  /// Cached answers to "can this query source reach Fn?", where the source is
+  /// either the whole function or a single call base inside it.
+  struct QuerySet {
+    /// Record \p Fn as reachable and drop it from the Unreachable set.
+    void markReachable(Function *Fn) {
+      Reachable.insert(Fn);
+      Unreachable.erase(Fn);
+    }
+
+    /// Re-check the queries assumed unreachable against the current edges; an
+    /// unknown callee makes every query reachable. Returns CHANGED on change.
+    ChangeStatus update(Attributor &A, const AAFunctionReachability &AA,
+                        ArrayRef<const AACallEdges *> AAEdgesList) {
+      ChangeStatus Change = ChangeStatus::UNCHANGED;
+
+      for (auto *AAEdges : AAEdgesList) {
+        if (AAEdges->hasUnknownCallee()) {
+          if (!CanReachUnknownCallee)
+            Change = ChangeStatus::CHANGED;
+          CanReachUnknownCallee = true;
+          return Change;
+        }
+      }
 
-  bool canReach(Attributor &A, Function *Fn) const override {
-    // Assume that we can reach any function if we can reach a call with
-    // unknown callee.
-    if (CanReachUnknownCallee)
-      return true;
+      for (Function *Fn : make_early_inc_range(Unreachable)) {
+        if (checkIfReachable(A, AA, AAEdgesList, Fn)) {
+          Change = ChangeStatus::CHANGED;
+          markReachable(Fn);
+        }
+      }
+      return Change;
+    }
 
-    if (ReachableQueries.count(Fn))
-      return true;
+    /// Answer (and cache) whether \p Fn is reachable via \p AAEdgesList.
+    bool isReachable(Attributor &A, const AAFunctionReachability &AA,
+                     ArrayRef<const AACallEdges *> AAEdgesList, Function *Fn) {
+      // Assume that we can reach the function.
+      // TODO: Be more specific with the unknown callee.
+      if (CanReachUnknownCallee)
+        return true;
+
+      if (Reachable.count(Fn))
+        return true;
+
+      if (Unreachable.count(Fn))
+        return false;
+
+      // We need to assume that this function can't reach Fn to prevent
+      // an infinite loop if this function is recursive.
+      Unreachable.insert(Fn);
+
+      bool Result = checkIfReachable(A, AA, AAEdgesList, Fn);
+      if (Result)
+        markReachable(Fn);
+      return Result;
+    }
+
+    /// Compute whether \p Fn is reachable via \p AAEdgesList. If it is not
+    /// (yet), record REQUIRED dependences of \p AA on the reachability AAs of
+    /// the optimistic callees so that \p AA is revisited when they change.
+    bool checkIfReachable(Attributor &A, const AAFunctionReachability &AA,
+                          ArrayRef<const AACallEdges *> AAEdgesList,
+                          Function *Fn) const {
+
+      // Handle the most trivial case first.
+      for (auto *AAEdges : AAEdgesList) {
+        const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();
+
+        if (Edges.count(Fn))
+          return true;
+      }
+
+      SmallVector<const AAFunctionReachability *, 8> Deps;
+      for (auto *AAEdges : AAEdgesList) {
+        const SetVector<Function *> &Edges = AAEdges->getOptimisticEdges();
+
+        for (Function *Edge : Edges) {
+          // We don't need a dependency if the result is reachable.
+          const AAFunctionReachability &EdgeReachability =
+              A.getAAFor<AAFunctionReachability>(
+                  AA, IRPosition::function(*Edge), DepClassTy::NONE);
+          Deps.push_back(&EdgeReachability);
+
+          if (EdgeReachability.canReach(A, Fn))
+            return true;
+        }
+      }
+
+      // The result is false for now, set dependencies and leave.
+      for (auto Dep : Deps)
+        A.recordDependence(AA, *Dep, DepClassTy::REQUIRED);
 
-    if (UnreachableQueries.count(Fn))
       return false;
+    }
+
+    /// Set of functions that we know for sure are reachable.
+    DenseSet<Function *> Reachable;
+
+    /// Set of functions that are unreachable, but might become reachable.
+    DenseSet<Function *> Unreachable;
 
+    /// If we can reach a function with a call to an unknown function we assume
+    /// that we can reach any function.
+    bool CanReachUnknownCallee = false;
+  };
+
+public:
+  AAFunctionReachabilityFunction(const IRPosition &IRP, Attributor &A)
+      : AAFunctionReachability(IRP, A) {}
+
+  bool canReach(Attributor &A, Function *Fn) const override {
     const AACallEdges &AAEdges =
         A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
 
-    const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges();
-    bool Result = checkIfReachable(A, Edges, Fn);
+    // Attributor returns attributes as const, so this function has to be
+    // const for users of this attribute to use it without having to do
+    // a const_cast.
+    // This is a hack for us to be able to cache queries.
+    auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+    bool Result =
+        NonConstThis->WholeFunction.isReachable(A, *this, {&AAEdges}, Fn);
+
+    return Result;
+  }
+
+  /// Can \p CB reach \p Fn (transitively)?
+  bool canReach(Attributor &A, CallBase &CB, Function *Fn) const override {
+    const AACallEdges &AAEdges = A.getAAFor<AACallEdges>(
+        *this, IRPosition::callsite_function(CB), DepClassTy::REQUIRED);
 
     // Attributor returns attributes as const, so this function has to be
     // const for users of this attribute to use it without having to do
     // a const_cast.
     // This is a hack for us to be able to cache queries.
     auto *NonConstThis = const_cast<AAFunctionReachabilityFunction *>(this);
+    QuerySet &CBQuery = NonConstThis->CBQueries[&CB];
 
-    if (Result)
-      NonConstThis->ReachableQueries.insert(Fn);
-    else
-      NonConstThis->UnreachableQueries.insert(Fn);
+    bool Result = CBQuery.isReachable(A, *this, {&AAEdges}, Fn);
 
     return Result;
   }
 
   /// See AbstractAttribute::updateImpl(...).
   ChangeStatus updateImpl(Attributor &A) override {
-    if (CanReachUnknownCallee)
-      return ChangeStatus::UNCHANGED;
-
     const AACallEdges &AAEdges =
         A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::REQUIRED);
-    const SetVector<Function *> &Edges = AAEdges.getOptimisticEdges();
     ChangeStatus Change = ChangeStatus::UNCHANGED;
 
-    if (AAEdges.hasUnknownCallee()) {
-      bool OldCanReachUnknown = CanReachUnknownCallee;
-      CanReachUnknownCallee = true;
-      return OldCanReachUnknown ? ChangeStatus::UNCHANGED
-                                : ChangeStatus::CHANGED;
-    }
+    Change |= WholeFunction.update(A, *this, {&AAEdges});
 
-    // Check if any of the unreachable functions become reachable.
-    for (auto Current = UnreachableQueries.begin();
-         Current != UnreachableQueries.end();) {
-      if (!checkIfReachable(A, Edges, *Current)) {
-        Current++;
-        continue;
-      }
-      ReachableQueries.insert(*Current);
-      UnreachableQueries.erase(*Current++);
-      Change = ChangeStatus::CHANGED;
+    for (auto &CBPair : CBQueries) {
+      const AACallEdges &CBEdges = A.getAAFor<AACallEdges>(
+          *this, IRPosition::callsite_function(*CBPair.first),
+          DepClassTy::REQUIRED);
+      // By reference: update() must mutate the cached QuerySet, not a copy.
+      Change |= CBPair.second.update(A, *this, {&CBEdges});
     }
 
     return Change;
   }
 
   const std::string getAsStr() const override {
-    size_t QueryCount = ReachableQueries.size() + UnreachableQueries.size();
+    size_t QueryCount =
+        WholeFunction.Reachable.size() + WholeFunction.Unreachable.size();
 
-    return "FunctionReachability [" + std::to_string(ReachableQueries.size()) +
-           "," + std::to_string(QueryCount) + "]";
+    return "FunctionReachability [" +
+           std::to_string(WholeFunction.Reachable.size()) + "," +
+           std::to_string(QueryCount) + "]";
   }
 
   void trackStatistics() const override {}
-
 private:
-  bool canReachUnknownCallee() const override { return CanReachUnknownCallee; }
-
-  bool checkIfReachable(Attributor &A, const SetVector<Function *> &Edges,
-                        Function *Fn) const {
-    if (Edges.count(Fn))
-      return true;
-
-    for (Function *Edge : Edges) {
-      // We don't need a dependency if the result is reachable.
-      const AAFunctionReachability &EdgeReachability =
-          A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Edge),
-                                             DepClassTy::NONE);
-
-      if (EdgeReachability.canReach(A, Fn))
-        return true;
-    }
-    for (Function *Fn : Edges)
-      A.getAAFor<AAFunctionReachability>(*this, IRPosition::function(*Fn),
-                                         DepClassTy::REQUIRED);
-
-    return false;
+  bool canReachUnknownCallee() const override {
+    return WholeFunction.CanReachUnknownCallee;
   }
 
-  /// Set of functions that we know for sure is reachable.
-  SmallPtrSet<Function *, 8> ReachableQueries;
-
-  /// Set of functions that are unreachable, but might become reachable.
-  SmallPtrSet<Function *, 8> UnreachableQueries;
+  /// Used to answer if the whole function can reach a specific function.
+  QuerySet WholeFunction;
 
-  /// If we can reach a function with a call to a unknown function we assume
-  /// that we can reach any function.
-  bool CanReachUnknownCallee = false;
+  /// Used to answer if a call base inside this function can reach a specific
+  /// function.
+  DenseMap<CallBase *, QuerySet> CBQueries;
 };
 
 } // namespace

diff  --git a/llvm/unittests/Transforms/IPO/AttributorTest.cpp b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
index 51de06c6d5e17..e1d2709e883e1 100644
--- a/llvm/unittests/Transforms/IPO/AttributorTest.cpp
+++ b/llvm/unittests/Transforms/IPO/AttributorTest.cpp
@@ -109,6 +109,13 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
       call void @func5(void ()* @func3)
       ret void
     }
+
+    define void @func7() {
+    entry:
+      call void @func2()
+      call void @func4()
+      ret void
+    }
   )";
 
   Module &M = parseModule(ModuleString);
@@ -127,6 +134,11 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
   Function *F3 = M.getFunction("func3");
   Function *F4 = M.getFunction("func4");
   Function *F6 = M.getFunction("func6");
+  Function *F7 = M.getFunction("func7");
+
+  // The first instruction of @func7's entry block is `call void @func2()`.
+  CallBase &F7FirstCB =
+      *cast<CallBase>(F7->getEntryBlock().getFirstNonPHI());
 
   const AAFunctionReachability &F1AA =
       A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F1));
@@ -134,15 +146,23 @@ TEST_F(AttributorTestBase, AAReachabilityTest) {
   const AAFunctionReachability &F6AA =
       A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F6));
 
+  const AAFunctionReachability &F7AA =
+      A.getOrCreateAAFor<AAFunctionReachability>(IRPosition::function(*F7));
+
   F1AA.canReach(A, F3);
   F1AA.canReach(A, F4);
   F6AA.canReach(A, F4);
+  F7AA.canReach(A, F7FirstCB, F3);
+  F7AA.canReach(A, F7FirstCB, F4);
 
   A.run();
 
   ASSERT_TRUE(F1AA.canReach(A, F3));
   ASSERT_FALSE(F1AA.canReach(A, F4));
 
+  ASSERT_TRUE(F7AA.canReach(A, F7FirstCB, F3));
+  ASSERT_FALSE(F7AA.canReach(A, F7FirstCB, F4));
+
   // Assumed to be reacahable, since F6 can reach a function with
   // a unknown callee.
   ASSERT_TRUE(F6AA.canReach(A, F4));


        


More information about the llvm-commits mailing list