[llvm] 8ea858b - [CallPromotionUtil] See through function alias when devirtualizing a virtual call on an alloca. (#80736)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 09:22:39 PST 2024


Author: Mingming Liu
Date: 2024-02-06T09:22:34-08:00
New Revision: 8ea858b96787578e814723a009f443808f446378

URL: https://github.com/llvm/llvm-project/commit/8ea858b96787578e814723a009f443808f446378
DIFF: https://github.com/llvm/llvm-project/commit/8ea858b96787578e814723a009f443808f446378.diff

LOG: [CallPromotionUtil] See through function alias when devirtualizing a virtual call on an alloca. (#80736)

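- Extract a utility function, `getFunctionAtVTableOffset`, from
`DevirtModule::tryFindVirtualCallTargets` [1]; it sees through an alias to a
function. Call this utility function at the WPD call site, and also in
`tryPromoteCall` (CallPromotionUtils), which previously did not see through
function aliases.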
- For the type-profiling work, the indirect-call-promotion pass will also use this helper to find the function pointer at a specified vtable offset (see [2] for an example).

[1] https://github.com/llvm/llvm-project/blob/b99163fe8feeacba7797d5479bbcd5d8f327dd2d/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp#L1069-L1082
[2] https://github.com/minglotus-6/llvm-project/blob/77a0ef12de82d11f448f7f9de6f2dcf87d9b74af/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp#L347

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TypeMetadataUtils.h
    llvm/lib/Analysis/TypeMetadataUtils.cpp
    llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
    llvm/lib/Transforms/Utils/CallPromotionUtils.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TypeMetadataUtils.h b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
index dab67aad1ab0e..8894945c28d94 100644
--- a/llvm/include/llvm/Analysis/TypeMetadataUtils.h
+++ b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
@@ -15,6 +15,7 @@
 #define LLVM_ANALYSIS_TYPEMETADATAUTILS_H
 
 #include <cstdint>
+#include <utility>
 
 namespace llvm {
 
@@ -24,6 +25,7 @@ class CallInst;
 class Constant;
 class Function;
 class DominatorTree;
+class GlobalVariable;
 class Instruction;
 class Module;
 
@@ -77,6 +79,13 @@ void findDevirtualizableCallsForTypeCheckedLoad(
 Constant *getPointerAtOffset(Constant *I, uint64_t Offset, Module &M,
                              Constant *TopLevelGlobal = nullptr);
 
+/// Returns {Fn, C} where C is the pointer at \p Offset in \p GV's initializer
+/// (which must exist), casts stripped. Fn is C if C is a Function, or the
+/// aliasee if C is an alias directly to a Function (then Fn != C: calling Fn
+/// bypasses the alias). Otherwise returns {nullptr, nullptr}.
+std::pair<Function *, Constant *>
+getFunctionAtVTableOffset(GlobalVariable *GV, uint64_t Offset, Module &M);
+
 /// Finds the same "relative pointer" pattern as described above, where the
 /// target is `F`, and replaces the entire pattern with a constant zero.
 void replaceRelativePointerUsersWithZero(Function *F);

diff  --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp
index bbaee06ed8a55..b8dcc39e9223c 100644
--- a/llvm/lib/Analysis/TypeMetadataUtils.cpp
+++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp
@@ -201,6 +201,26 @@ Constant *llvm::getPointerAtOffset(Constant *I, uint64_t Offset, Module &M,
   return nullptr;
 }
 
+std::pair<Function *, Constant *>
+llvm::getFunctionAtVTableOffset(GlobalVariable *GV, uint64_t Offset,
+                                Module &M) {
+  Constant *Ptr = getPointerAtOffset(GV->getInitializer(), Offset, M, GV);
+  if (!Ptr)
+    return std::pair<Function *, Constant *>(nullptr, nullptr);
+  // Look through pointer casts to reach the Function or GlobalAlias itself.
+  auto C = Ptr->stripPointerCasts();
+  // Accept a Function, or a GlobalAlias whose aliasee is itself a Function.
+  auto Fn = dyn_cast<Function>(C);
+  auto A = dyn_cast<GlobalAlias>(C);
+  if (!Fn && A)
+    Fn = dyn_cast<Function>(A->getAliasee());
+  // Fn is null if C is neither kind or its aliasee is not a plain Function.
+  if (!Fn)
+    return std::pair<Function *, Constant *>(nullptr, nullptr);
+  // For an alias, C is the alias symbol and Fn is its aliasee Function.
+  return std::pair<Function *, Constant *>(Fn, C);
+}
+
 void llvm::replaceRelativePointerUsersWithZero(Function *F) {
   for (auto *U : F->users()) {
     auto *PtrExpr = dyn_cast<ConstantExpr>(U);

diff  --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 01aba47cdbfff..75f7de4290a74 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -1066,17 +1066,10 @@ bool DevirtModule::tryFindVirtualCallTargets(
         GlobalObject::VCallVisibilityPublic)
       return false;
 
-    Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
-                                       TM.Offset + ByteOffset, M, TM.Bits->GV);
-    if (!Ptr)
-      return false;
-
-    auto C = Ptr->stripPointerCasts();
-    // Make sure this is a function or alias to a function.
-    auto Fn = dyn_cast<Function>(C);
-    auto A = dyn_cast<GlobalAlias>(C);
-    if (!Fn && A)
-      Fn = dyn_cast<Function>(A->getAliasee());
+    Function *Fn = nullptr;
+    Constant *C = nullptr;
+    std::tie(Fn, C) =
+        getFunctionAtVTableOffset(TM.Bits->GV, TM.Offset + ByteOffset, M);
 
     if (!Fn)
       return false;

diff  --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index e42cdab64446e..4e84927f1cfc9 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -597,16 +597,13 @@ bool llvm::tryPromoteCall(CallBase &CB) {
     // Not in the form of a global constant variable with an initializer.
     return false;
 
-  Constant *VTableGVInitializer = GV->getInitializer();
   APInt VTableGVOffset = VTableOffsetGVBase + VTableOffset;
   if (!(VTableGVOffset.getActiveBits() <= 64))
     return false; // Out of range.
-  Constant *Ptr = getPointerAtOffset(VTableGVInitializer,
-                                     VTableGVOffset.getZExtValue(),
-                                     *M);
-  if (!Ptr)
-    return false; // No constant (function) pointer found.
-  Function *DirectCallee = dyn_cast<Function>(Ptr->stripPointerCasts());
+
+  Function *DirectCallee = nullptr;
+  std::tie(DirectCallee, std::ignore) =
+      getFunctionAtVTableOffset(GV, VTableGVOffset.getZExtValue(), *M);
   if (!DirectCallee)
     return false; // No function pointer found.
 


        


More information about the llvm-commits mailing list