[llvm] cea6f4d - [llvm][NFC][CallSite] Remove CallSite from TypeMetadataUtils & related

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 23 08:23:36 PDT 2020


Author: Mircea Trofin
Date: 2020-04-23T08:23:16-07:00
New Revision: cea6f4d5f8431ea723e85f6e57812feb3633ecbe

URL: https://github.com/llvm/llvm-project/commit/cea6f4d5f8431ea723e85f6e57812feb3633ecbe
DIFF: https://github.com/llvm/llvm-project/commit/cea6f4d5f8431ea723e85f6e57812feb3633ecbe.diff

LOG: [llvm][NFC][CallSite] Remove CallSite from TypeMetadataUtils & related

Reviewers: craig.topper, dblaikie

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D78666

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/TypeMetadataUtils.h
    llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
    llvm/lib/Analysis/TypeMetadataUtils.cpp
    llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/TypeMetadataUtils.h b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
index 43ce26147c2e..ea59c0cbc871 100644
--- a/llvm/include/llvm/Analysis/TypeMetadataUtils.h
+++ b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
@@ -15,7 +15,7 @@
 #define LLVM_ANALYSIS_TYPEMETADATAUTILS_H
 
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Instructions.h"
 
 namespace llvm {
 
@@ -33,7 +33,7 @@ struct DevirtCallSite {
   /// The offset from the address point to the virtual function.
   uint64_t Offset;
   /// The call site itself.
-  CallSite CS;
+  CallBase &CB;
 };
 
 /// Given a call to the intrinsic \@llvm.type.test, find all devirtualizable

diff  --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
index 1ff47e10bd99..b78115bd43f7 100644
--- a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
+++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
@@ -28,7 +28,6 @@
 #include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
@@ -99,7 +98,7 @@ static bool findRefEdges(ModuleSummaryIndex &Index, const User *CurUser,
     if (!Visited.insert(U).second)
       continue;
 
-    ImmutableCallSite CS(U);
+    const auto *CB = dyn_cast<CallBase>(U);
 
     for (const auto &OI : U->operands()) {
       const User *Operand = dyn_cast<User>(OI);
@@ -113,7 +112,7 @@ static bool findRefEdges(ModuleSummaryIndex &Index, const User *CurUser,
         // We have a reference to a global value. This should be added to
         // the reference set unless it is a callee. Callees are handled
         // specially by WriteFunction and are added to a separate list.
-        if (!(CS && CS.isCallee(&OI)))
+        if (!(CB && CB->isCallee(&OI)))
           RefEdges.insert(Index.getOrInsertValueInfo(GV));
         continue;
       }
@@ -145,7 +144,7 @@ static void addVCallToSet(DevirtCallSite Call, GlobalValue::GUID Guid,
                           SetVector<FunctionSummary::ConstVCall> &ConstVCalls) {
   std::vector<uint64_t> Args;
   // Start from the second argument to skip the "this" pointer.
-  for (auto &Arg : make_range(Call.CS.arg_begin() + 1, Call.CS.arg_end())) {
+  for (auto &Arg : make_range(Call.CB.arg_begin() + 1, Call.CB.arg_end())) {
     auto *CI = dyn_cast<ConstantInt>(Arg);
     if (!CI || CI->getBitWidth() > 64) {
       VCalls.insert({Guid, Call.Offset});
@@ -304,8 +303,8 @@ static void computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M,
         }
       }
       findRefEdges(Index, &I, RefEdges, Visited);
-      auto CS = ImmutableCallSite(&I);
-      if (!CS)
+      const auto *CB = dyn_cast<CallBase>(&I);
+      if (!CB)
         continue;
 
       const auto *CI = dyn_cast<CallInst>(&I);
@@ -317,8 +316,8 @@ static void computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M,
       if (HasLocalsInUsedOrAsm && CI && CI->isInlineAsm())
         HasInlineAsmMaybeReferencingInternal = true;
 
-      auto *CalledValue = CS.getCalledValue();
-      auto *CalledFunction = CS.getCalledFunction();
+      auto *CalledValue = CB->getCalledValue();
+      auto *CalledFunction = CB->getCalledFunction();
       if (CalledValue && !CalledFunction) {
         CalledValue = CalledValue->stripPointerCasts();
         // Stripping pointer casts can reveal a called function.

diff  --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp
index 072d291f3f93..f187d0796771 100644
--- a/llvm/lib/Analysis/TypeMetadataUtils.cpp
+++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp
@@ -37,10 +37,10 @@ findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
     if (isa<BitCastInst>(User)) {
       findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI,
                                 DT);
-    } else if (auto CI = dyn_cast<CallInst>(User)) {
-      DevirtCalls.push_back({Offset, CI});
-    } else if (auto II = dyn_cast<InvokeInst>(User)) {
-      DevirtCalls.push_back({Offset, II});
+    } else if (auto *CI = dyn_cast<CallInst>(User)) {
+      DevirtCalls.push_back({Offset, *CI});
+    } else if (auto *II = dyn_cast<InvokeInst>(User)) {
+      DevirtCalls.push_back({Offset, *II});
     } else if (HasNonCallUses) {
       *HasNonCallUses = true;
     }

diff  --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 60f8e935ffdd..2e72b2981d83 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -64,7 +64,6 @@
 #include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/Bitcode/BitcodeReader.h"
 #include "llvm/Bitcode/BitcodeWriter.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugLoc.h"
@@ -354,20 +353,20 @@ namespace {
 // A virtual call site. VTable is the loaded virtual table pointer, and CS is
 // the indirect virtual call.
 struct VirtualCallSite {
-  Value *VTable;
-  CallSite CS;
+  Value *VTable = nullptr;
+  CallBase &CB;
 
   // If non-null, this field points to the associated unsafe use count stored in
   // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
   // of that field for details.
-  unsigned *NumUnsafeUses;
+  unsigned *NumUnsafeUses = nullptr;
 
   void
   emitRemark(const StringRef OptName, const StringRef TargetName,
              function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
-    Function *F = CS.getCaller();
-    DebugLoc DLoc = CS->getDebugLoc();
-    BasicBlock *Block = CS.getParent();
+    Function *F = CB.getCaller();
+    DebugLoc DLoc = CB.getDebugLoc();
+    BasicBlock *Block = CB.getParent();
 
     using namespace ore;
     OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
@@ -382,12 +381,12 @@ struct VirtualCallSite {
       Value *New) {
     if (RemarksEnabled)
       emitRemark(OptName, TargetName, OREGetter);
-    CS->replaceAllUsesWith(New);
-    if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
-      BranchInst::Create(II->getNormalDest(), CS.getInstruction());
+    CB.replaceAllUsesWith(New);
+    if (auto *II = dyn_cast<InvokeInst>(&CB)) {
+      BranchInst::Create(II->getNormalDest(), &CB);
       II->getUnwindDest()->removePredecessor(II->getParent());
     }
-    CS->eraseFromParent();
+    CB.eraseFromParent();
     // This use is no longer unsafe.
     if (NumUnsafeUses)
       --*NumUnsafeUses;
@@ -460,18 +459,18 @@ struct VTableSlotInfo {
   // "this"), grouped by argument list.
   std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
 
-  void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
+  void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);
 
 private:
-  CallSiteInfo &findCallSiteInfo(CallSite CS);
+  CallSiteInfo &findCallSiteInfo(CallBase &CB);
 };
 
-CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
+CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
   std::vector<uint64_t> Args;
-  auto *CI = dyn_cast<IntegerType>(CS.getType());
-  if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
+  auto *CBType = dyn_cast<IntegerType>(CB.getType());
+  if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
     return CSInfo;
-  for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
+  for (auto &&Arg : make_range(CB.arg_begin() + 1, CB.arg_end())) {
     auto *CI = dyn_cast<ConstantInt>(Arg);
     if (!CI || CI->getBitWidth() > 64)
       return CSInfo;
@@ -480,11 +479,11 @@ CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
   return ConstCSInfo[Args];
 }
 
-void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
+void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
                                  unsigned *NumUnsafeUses) {
-  auto &CSI = findCallSiteInfo(CS);
+  auto &CSI = findCallSiteInfo(CB);
   CSI.AllCallSitesDevirted = false;
-  CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
+  CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
 }
 
 struct DevirtModule {
@@ -1029,8 +1028,8 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
       if (RemarksEnabled)
         VCallSite.emitRemark("single-impl",
                              TheFn->stripPointerCasts()->getName(), OREGetter);
-      VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
-          TheFn, VCallSite.CS.getCalledValue()->getType()));
+      VCallSite.CB.setCalledOperand(ConstantExpr::getBitCast(
+          TheFn, VCallSite.CB.getCalledValue()->getType()));
       // This use is no longer unsafe.
       if (VCallSite.NumUnsafeUses)
         --*VCallSite.NumUnsafeUses;
@@ -1253,10 +1252,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
     if (CSInfo.AllCallSitesDevirted)
       return;
     for (auto &&VCallSite : CSInfo.CallSites) {
-      CallSite CS = VCallSite.CS;
+      CallBase &CB = VCallSite.CB;
 
       // Jump tables are only profitable if the retpoline mitigation is enabled.
-      Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
+      Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
       if (FSAttr.hasAttribute(Attribute::None) ||
           !FSAttr.getValueAsString().contains("+retpoline"))
         continue;
@@ -1269,42 +1268,40 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
       // x86_64.
       std::vector<Type *> NewArgs;
       NewArgs.push_back(Int8PtrTy);
-      for (Type *T : CS.getFunctionType()->params())
+      for (Type *T : CB.getFunctionType()->params())
         NewArgs.push_back(T);
       FunctionType *NewFT =
-          FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
-                            CS.getFunctionType()->isVarArg());
+          FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
+                            CB.getFunctionType()->isVarArg());
       PointerType *NewFTPtr = PointerType::getUnqual(NewFT);
 
-      IRBuilder<> IRB(CS.getInstruction());
+      IRBuilder<> IRB(&CB);
       std::vector<Value *> Args;
       Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
-      for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
-        Args.push_back(CS.getArgOperand(I));
+      Args.insert(Args.end(), CB.arg_begin(), CB.arg_end());
 
-      CallSite NewCS;
-      if (CS.isCall())
+      CallBase *NewCS = nullptr;
+      if (isa<CallInst>(CB))
         NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
       else
-        NewCS = IRB.CreateInvoke(
-            NewFT, IRB.CreateBitCast(JT, NewFTPtr),
-            cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
-            cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
-      NewCS.setCallingConv(CS.getCallingConv());
+        NewCS = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr),
+                                 cast<InvokeInst>(CB).getNormalDest(),
+                                 cast<InvokeInst>(CB).getUnwindDest(), Args);
+      NewCS->setCallingConv(CB.getCallingConv());
 
-      AttributeList Attrs = CS.getAttributes();
+      AttributeList Attrs = CB.getAttributes();
       std::vector<AttributeSet> NewArgAttrs;
       NewArgAttrs.push_back(AttributeSet::get(
           M.getContext(), ArrayRef<Attribute>{Attribute::get(
                               M.getContext(), Attribute::Nest)}));
       for (unsigned I = 0; I + 2 <  Attrs.getNumAttrSets(); ++I)
         NewArgAttrs.push_back(Attrs.getParamAttributes(I));
-      NewCS.setAttributes(
+      NewCS->setAttributes(
           AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
                              Attrs.getRetAttributes(), NewArgAttrs));
 
-      CS->replaceAllUsesWith(NewCS.getInstruction());
-      CS->eraseFromParent();
+      CB.replaceAllUsesWith(NewCS);
+      CB.eraseFromParent();
 
       // This use is no longer unsafe.
       if (VCallSite.NumUnsafeUses)
@@ -1355,7 +1352,7 @@ void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
   for (auto Call : CSInfo.CallSites)
     Call.replaceAndErase(
         "uniform-ret-val", FnName, RemarksEnabled, OREGetter,
-        ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
+        ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
   CSInfo.markDevirt();
 }
 
@@ -1461,11 +1458,11 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                         bool IsOne,
                                         Constant *UniqueMemberAddr) {
   for (auto &&Call : CSInfo.CallSites) {
-    IRBuilder<> B(Call.CS.getInstruction());
+    IRBuilder<> B(&Call.CB);
     Value *Cmp =
         B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
                      B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
-    Cmp = B.CreateZExt(Cmp, Call.CS->getType());
+    Cmp = B.CreateZExt(Cmp, Call.CB.getType());
     Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
                          Cmp);
   }
@@ -1529,8 +1526,8 @@ bool DevirtModule::tryUniqueRetValOpt(
 void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                                          Constant *Byte, Constant *Bit) {
   for (auto Call : CSInfo.CallSites) {
-    auto *RetType = cast<IntegerType>(Call.CS.getType());
-    IRBuilder<> B(Call.CS.getInstruction());
+    auto *RetType = cast<IntegerType>(Call.CB.getType());
+    IRBuilder<> B(&Call.CB);
     Value *Addr =
         B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
     if (RetType->getBitWidth() == 1) {
@@ -1716,7 +1713,7 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc) {
   // points to a member of the type identifier %md. Group calls by (type ID,
   // offset) pair (effectively the identity of the virtual function) and store
   // to CallSlots.
-  DenseSet<CallSite> SeenCallSites;
+  DenseSet<CallBase *> SeenCallSites;
   for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
        I != E;) {
     auto CI = dyn_cast<CallInst>(I->getUser());
@@ -1741,8 +1738,8 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc) {
         // and we don't want to process call sites multiple times. We can't
         // just skip the vtable Ptr if it has been seen before, however, since
         // it may be shared by type tests that dominate different calls.
-        if (SeenCallSites.insert(Call.CS).second)
-          CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr);
+        if (SeenCallSites.insert(&Call.CB).second)
+          CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
       }
     }
 
@@ -1828,7 +1825,7 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
     if (HasNonCallUses)
       ++NumUnsafeUses;
     for (DevirtCallSite Call : DevirtCalls) {
-      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
+      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
                                                    &NumUnsafeUses);
     }
 


        


More information about the llvm-commits mailing list