[clang] [compiler-rt] [llvm] [TypeProf][InstrFDO]Implement more efficient comparison sequence for indirect-call-promotion with vtable profiles. (PR #81442)

David Li via cfe-commits cfe-commits at lists.llvm.org
Wed Jun 12 13:55:03 PDT 2024


================
@@ -321,14 +746,127 @@ bool IndirectCallPromoter::processFunction(ProfileSummaryInfo *PSI) {
     if (!NumCandidates ||
         (PSI && PSI->hasProfileSummary() && !PSI->isHotCount(TotalCount)))
       continue;
+
     auto PromotionCandidates = getPromotionCandidatesForCallSite(
         *CB, ICallProfDataRef, TotalCount, NumCandidates);
-    Changed |= tryToPromoteWithFuncCmp(*CB, PromotionCandidates, TotalCount,
-                                       ICallProfDataRef, NumCandidates);
+
+    VTableGUIDCountsMap VTableGUIDCounts;
+    Instruction *VPtr =
+        computeVTableInfos(CB, VTableGUIDCounts, PromotionCandidates);
+
+    if (isProfitableToCompareVTables(PromotionCandidates, TotalCount))
+      Changed |= tryToPromoteWithVTableCmp(*CB, VPtr, PromotionCandidates,
+                                           TotalCount, NumCandidates,
+                                           ICallProfDataRef, VTableGUIDCounts);
+    else
+      Changed |= tryToPromoteWithFuncCmp(*CB, VPtr, PromotionCandidates,
+                                         TotalCount, ICallProfDataRef,
+                                         NumCandidates, VTableGUIDCounts);
   }
   return Changed;
 }
 
+// Decides whether to promote this call site by comparing vtable addresses
+// instead of function addresses.
+//
+// Returns false when any of the following holds:
+//  - vtable comparison is disabled (ICPEnableVTableCmp) or there are no
+//    candidates;
+//  - a candidate's summed vtable counts are below
+//    Candidate.Count * ICPVTableCountPercentage;
+//  - a candidate has more address points than allowed. Non-last candidates
+//    may have at most one; the last may have 1 + ICPNumAdditionalVTableLast;
+//  - a profile summary is available and the count left over after all
+//    candidates (the indirect fallback) is not cold.
+//
+// TODO: Return false if the function-address and vtable-load instructions
+// cannot be sunk into the indirect fallback.
+bool IndirectCallPromoter::isProfitableToCompareVTables(
+    const std::vector<PromotionCandidate> &Candidates, uint64_t TotalCount) {
+  if (!ICPEnableVTableCmp || Candidates.empty())
+    return false;
+
+  // Count not attributed to any candidate, i.e. what reaches the fallback.
+  uint64_t FallbackCount = TotalCount;
+  const size_t LastIdx = Candidates.size() - 1;
+
+  for (size_t Idx = 0; Idx <= LastIdx; ++Idx) {
+    const PromotionCandidate &Cand = Candidates[Idx];
+
+    // The vtable profile must cover enough of this candidate's calls.
+    uint64_t ProfiledVTableCount = 0;
+    for (const auto &[VTableGUID, VTableCount] : Cand.VTableGUIDAndCounts)
+      ProfiledVTableCount += VTableCount;
+    if (ProfiledVTableCount < Cand.Count * ICPVTableCountPercentage)
+      return false;
+
+    FallbackCount -= Cand.Count;
+
+    // Only the last candidate may compare against extra vtables.
+    int AllowedExtraVTables = 0;
+    if (Idx == LastIdx)
+      AllowedExtraVTables = ICPNumAdditionalVTableLast;
+
+    // Explicit narrowing; an empty AddressPoints yields -1 here.
+    const int ExtraVTables = static_cast<int>(Cand.AddressPoints.size() - 1);
+    if (ExtraVTables > AllowedExtraVTables)
+      return false;
+  }
+
+  // Without a profile summary the fallback's hotness cannot be judged, so
+  // the coldness requirement is skipped.
+  const bool HasSummary = PSI && PSI->hasProfileSummary();
+  return !HasSummary || PSI->isColdCount(FallbackCount);
+}
+
+static void
+computeVirtualCallSiteTypeInfoMap(Module &M, ModuleAnalysisManager &MAM,
+                                  VirtualCallSiteTypeInfoMap &VirtualCSInfo) {
+  auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
+    return FAM.getResult<DominatorTreeAnalysis>(F);
+  };
+
+  auto compute = [&](Function *Func) {
+    if (!Func || Func->use_empty())
+      return;
+    // Iterate all type.test calls and find all indirect calls.
+    // TODO: Add llvm.public.type.test
+    for (Use &U : llvm::make_early_inc_range(Func->uses())) {
+      auto *CI = dyn_cast<CallInst>(U.getUser());
+      if (!CI)
+        continue;
+      auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1));
+      if (!TypeMDVal)
+        continue;
+      auto *CompatibleTypeId = dyn_cast<MDString>(TypeMDVal->getMetadata());
+      if (!CompatibleTypeId)
+        continue;
+
+      // Find out all devirtualizable call sites given a llvm.type.test
+      // intrinsic call.
+      SmallVector<DevirtCallSite, 1> DevirtCalls;
+      SmallVector<CallInst *, 1> Assumes;
+      auto &DT = LookupDomTree(*CI->getFunction());
+      findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);
+
+      // type-id, offset from the address point
+      // combined with type metadata to compute function offset
+      for (auto &DevirtCall : DevirtCalls) {
+        CallBase &CB = DevirtCall.CB;
+        // Given an indirect call, try find the instruction which loads a
+        // pointer to virtual table.
+        Instruction *VTablePtr =
+            PGOIndirectCallVisitor::tryGetVTableInstruction(&CB);
+        if (!VTablePtr)
+          continue;
+        VirtualCSInfo[&CB] = {DevirtCall.Offset, VTablePtr,
+                              CompatibleTypeId->getString()};
+      }
+    }
+  };
+
+  // Right now only llvm.type.test is used to find out virtual call sites.
+  // With ThinLTO and whole-program-devirtualization, llvm.type.test and
+  // llvm.public.type.test are emitted, and llvm.public.type.test is either
+  // refined to llvm.type.test or dropped before indirect-call-promotion pass.
----------------
david-xl wrote:

ok

https://github.com/llvm/llvm-project/pull/81442


More information about the cfe-commits mailing list