[llvm] d2f1cd5 - [llvm][NFC] Refactor uses of CallSite to CallBase - call promotion

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 12 08:29:28 PDT 2020


Author: Mircea Trofin
Date: 2020-04-12T08:27:29-07:00
New Revision: d2f1cd5d9712276730f09745825fb6d71c51e639

URL: https://github.com/llvm/llvm-project/commit/d2f1cd5d9712276730f09745825fb6d71c51e639
DIFF: https://github.com/llvm/llvm-project/commit/d2f1cd5d9712276730f09745825fb6d71c51e639.diff

LOG: [llvm][NFC] Refactor uses of CallSite to CallBase - call promotion

Summary:
Updated CallPromotionUtils and impacted sites. Parameters that are
expected to be non-null, and return values that are guaranteed non-null,
were replaced with CallBase references rather than pointers.

Left FIXMEs in places where further changes are facilitated by CallBase
but do not involve CallSite: for example, Instruction* parameters or
return values where the contract is that they are actually CallBase values.

Reviewers: davidxl, dblaikie, wmi

Reviewed By: dblaikie

Subscribers: arsenm, jvesely, nhaehnle, eraman, hiraditya, kerbowa, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77930

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
    llvm/include/llvm/Transforms/Utils/Cloning.h
    llvm/lib/Target/AMDGPU/AMDGPUFixFunctionBitcasts.cpp
    llvm/lib/Transforms/IPO/Inliner.cpp
    llvm/lib/Transforms/IPO/SampleProfile.cpp
    llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
    llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
    llvm/lib/Transforms/Utils/InlineFunction.cpp
    llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
index 938c5d2ea469..693550192369 100644
--- a/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/CallPromotionUtils.h
@@ -14,9 +14,11 @@
 #ifndef LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
 #define LLVM_TRANSFORMS_UTILS_CALLPROMOTIONUTILS_H
 
-#include "llvm/IR/CallSite.h"
-
 namespace llvm {
+class CallBase;
+class CastInst;
+class Function;
+class MDNode;
 
 /// Return true if the given indirect call site can be made to call \p Callee.
 ///
@@ -25,7 +27,7 @@ namespace llvm {
 /// match exactly, they must at least be bitcast compatible. If \p FailureReason
 /// is non-null and the indirect call cannot be promoted, the failure reason
 /// will be stored in it.
-bool isLegalToPromote(CallSite CS, Function *Callee,
+bool isLegalToPromote(CallBase &CB, Function *Callee,
                       const char **FailureReason = nullptr);
 
 /// Promote the given indirect call site to unconditionally call \p Callee.
@@ -35,8 +37,8 @@ bool isLegalToPromote(CallSite CS, Function *Callee,
 /// of the callee, bitcast instructions are inserted where appropriate. If \p
 /// RetBitCast is non-null, it will be used to store the return value bitcast,
 /// if created.
-Instruction *promoteCall(CallSite CS, Function *Callee,
-                         CastInst **RetBitCast = nullptr);
+CallBase &promoteCall(CallBase &CB, Function *Callee,
+                      CastInst **RetBitCast = nullptr);
 
 /// Promote the given indirect call site to conditionally call \p Callee.
 ///
@@ -45,8 +47,8 @@ Instruction *promoteCall(CallSite CS, Function *Callee,
 /// indirect call site is promoted, placed in the "then" block, and returned. If
 /// \p BranchWeights is non-null, it will be used to set !prof metadata on the
 /// new conditional branch.
-Instruction *promoteCallWithIfThenElse(CallSite CS, Function *Callee,
-                                       MDNode *BranchWeights = nullptr);
+CallBase &promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
+                                    MDNode *BranchWeights = nullptr);
 
 /// Try to promote (devirtualize) a virtual call on an Alloca. Return true on
 /// success.
@@ -69,7 +71,7 @@ Instruction *promoteCallWithIfThenElse(CallSite CS, Function *Callee,
 ///     [i8* null, i8* bitcast ({ i8*, i8*, i8* }* @_ZTI4Impl to i8*),
 ///     i8* bitcast (void (%class.Impl*)* @_ZN4Impl3RunEv to i8*)] }
 ///
-bool tryPromoteCall(CallSite &CS);
+bool tryPromoteCall(CallBase &CB);
 
 } // end namespace llvm
 

diff  --git a/llvm/include/llvm/Transforms/Utils/Cloning.h b/llvm/include/llvm/Transforms/Utils/Cloning.h
index 872ab9cab85c..a939f192e418 100644
--- a/llvm/include/llvm/Transforms/Utils/Cloning.h
+++ b/llvm/include/llvm/Transforms/Utils/Cloning.h
@@ -22,7 +22,6 @@
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/InlineCost.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include <functional>
@@ -201,7 +200,7 @@ class InlineFunctionInfo {
   /// 'InlineFunction' fills this in by scanning the inlined instructions, and
   /// only if CG is null. If CG is non-null, instead the value handle
   /// `InlinedCalls` above is used.
-  SmallVector<CallSite, 8> InlinedCallSites;
+  SmallVector<CallBase *, 8> InlinedCallSites;
 
   void reset() {
     StaticAllocas.clear();

diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUFixFunctionBitcasts.cpp b/llvm/lib/Target/AMDGPU/AMDGPUFixFunctionBitcasts.cpp
index 9ba04d113c70..2c23a163e70b 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUFixFunctionBitcasts.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUFixFunctionBitcasts.cpp
@@ -35,8 +35,9 @@ class AMDGPUFixFunctionBitcasts final
     if (CS.getCalledFunction())
       return;
     auto Callee = dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts());
-    if (Callee && isLegalToPromote(CS, Callee)) {
-      promoteCall(CS, Callee);
+    if (Callee &&
+        isLegalToPromote(*cast<CallBase>(CS.getInstruction()), Callee)) {
+      promoteCall(*cast<CallBase>(CS.getInstruction()), Callee);
       Modified = true;
     }
   }

diff  --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp
index 6a4e1197bf8e..dc02305c90a5 100644
--- a/llvm/lib/Transforms/IPO/Inliner.cpp
+++ b/llvm/lib/Transforms/IPO/Inliner.cpp
@@ -38,7 +38,6 @@
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/IR/DerivedTypes.h"
@@ -1113,20 +1112,19 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC,
         InlineHistory.push_back({&Callee, InlineHistoryID});
 
         // FIXME(mtrofin): refactor IFI.InlinedCallSites to be CallBase-based
-        for (CallSite &CS : reverse(IFI.InlinedCallSites)) {
-          Function *NewCallee = CS.getCalledFunction();
+        for (CallBase *CS : reverse(IFI.InlinedCallSites)) {
+          Function *NewCallee = CS->getCalledFunction();
           if (!NewCallee) {
             // Try to promote an indirect (virtual) call without waiting for the
             // post-inline cleanup and the next DevirtSCCRepeatedPass iteration
             // because the next iteration may not happen and we may miss
             // inlining it.
-            if (tryPromoteCall(CS))
-              NewCallee = CS.getCalledFunction();
+            if (tryPromoteCall(*CS))
+              NewCallee = CS->getCalledFunction();
           }
           if (NewCallee)
             if (!NewCallee->isDeclaration())
-              Calls.push_back(
-                  {cast<CallBase>(CS.getInstruction()), NewHistoryID});
+              Calls.push_back({CS, NewHistoryID});
         }
       }
 

diff  --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 9c1369fdcf99..9db721602e20 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -46,7 +46,6 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -984,6 +983,8 @@ bool SampleProfileLoader::inlineHotFunctions(
          "ProfAccForSymsInList should be false when profile-sample-accurate "
          "is enabled");
 
+  // FIXME(CallSite): refactor the vectors here, as they operate with CallBase
+  // values
   DenseMap<Instruction *, const FunctionSamples *> localNotInlinedCallSites;
   bool Changed = false;
   while (true) {
@@ -1047,7 +1048,7 @@ bool SampleProfileLoader::inlineHotFunctions(
           if (R != SymbolMap.end() && R->getValue() &&
               !R->getValue()->isDeclaration() &&
               R->getValue()->getSubprogram() &&
-              isLegalToPromote(CallSite(I), R->getValue(), &Reason)) {
+              isLegalToPromote(*cast<CallBase>(I), R->getValue(), &Reason)) {
             uint64_t C = FS->getEntrySamples();
             Instruction *DI =
                 pgo::promoteIndirectCall(I, R->getValue(), C, Sum, false, ORE);

diff  --git a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
index d5787c8f62a1..9857769e880f 100644
--- a/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
+++ b/llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp
@@ -23,7 +23,6 @@
 #include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/Function.h"
@@ -217,6 +216,7 @@ class ICallPromotionFunc {
 
 // Indirect-call promotion heuristic. The direct targets are sorted based on
 // the count. Stop at the first target that is not promoted.
+// FIXME(callsite): the Instruction* parameter can be changed to CallBase
 std::vector<ICallPromotionFunc::PromotionCandidate>
 ICallPromotionFunc::getPromotionCandidatesForCallSite(
     Instruction *Inst, const ArrayRef<InstrProfValueData> &ValueDataRef,
@@ -276,7 +276,7 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite(
     }
 
     const char *Reason = nullptr;
-    if (!isLegalToPromote(CallSite(Inst), TargetFunction, &Reason)) {
+    if (!isLegalToPromote(*cast<CallBase>(Inst), TargetFunction, &Reason)) {
       using namespace ore;
 
       ORE.emit([&]() {
@@ -294,6 +294,8 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite(
   return Ret;
 }
 
+// FIXME(callsite): the Instruction* parameter and return can be changed to
+// CallBase
 Instruction *llvm::pgo::promoteIndirectCall(Instruction *Inst,
                                             Function *DirectCallee,
                                             uint64_t Count, uint64_t TotalCount,
@@ -307,12 +309,12 @@ Instruction *llvm::pgo::promoteIndirectCall(Instruction *Inst,
   MDNode *BranchWeights = MDB.createBranchWeights(
       scaleBranchCount(Count, Scale), scaleBranchCount(ElseCount, Scale));
 
-  Instruction *NewInst =
-      promoteCallWithIfThenElse(CallSite(Inst), DirectCallee, BranchWeights);
+  CallBase &NewInst = promoteCallWithIfThenElse(*cast<CallBase>(Inst),
+                                                DirectCallee, BranchWeights);
 
   if (AttachProfToDirectCall) {
-    MDBuilder MDB(NewInst->getContext());
-    NewInst->setMetadata(
+    MDBuilder MDB(NewInst.getContext());
+    NewInst.setMetadata(
         LLVMContext::MD_prof,
         MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
   }
@@ -326,7 +328,7 @@ Instruction *llvm::pgo::promoteIndirectCall(Instruction *Inst,
              << " with count " << NV("Count", Count) << " out of "
              << NV("TotalCount", TotalCount);
     });
-  return NewInst;
+  return &NewInst;
 }
 
 // Promote indirect-call to conditional direct-call for one callsite.

diff  --git a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
index 19e7efe328fc..d087717d1ecf 100644
--- a/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
+++ b/llvm/lib/Transforms/Utils/CallPromotionUtils.cpp
@@ -15,6 +15,7 @@
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Analysis/TypeMetadataUtils.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
 using namespace llvm;
@@ -160,32 +161,31 @@ static void createRetPHINode(Instruction *OrigInst, Instruction *NewInst,
 ///     %t1 = bitcast i32 %t0 to ...
 ///     br label %normal_dst
 ///
-static void createRetBitCast(CallSite CS, Type *RetTy, CastInst **RetBitCast) {
+static void createRetBitCast(CallBase &CB, Type *RetTy, CastInst **RetBitCast) {
 
   // Save the users of the calling instruction. These uses will be changed to
   // use the bitcast after we create it.
   SmallVector<User *, 16> UsersToUpdate;
-  for (User *U : CS.getInstruction()->users())
+  for (User *U : CB.users())
     UsersToUpdate.push_back(U);
 
   // Determine an appropriate location to create the bitcast for the return
   // value. The location depends on if we have a call or invoke instruction.
   Instruction *InsertBefore = nullptr;
-  if (auto *Invoke = dyn_cast<InvokeInst>(CS.getInstruction()))
+  if (auto *Invoke = dyn_cast<InvokeInst>(&CB))
     InsertBefore =
         &SplitEdge(Invoke->getParent(), Invoke->getNormalDest())->front();
   else
-    InsertBefore = &*std::next(CS.getInstruction()->getIterator());
+    InsertBefore = &*std::next(CB.getIterator());
 
   // Bitcast the return value to the correct type.
-  auto *Cast = CastInst::CreateBitOrPointerCast(CS.getInstruction(), RetTy, "",
-                                                InsertBefore);
+  auto *Cast = CastInst::CreateBitOrPointerCast(&CB, RetTy, "", InsertBefore);
   if (RetBitCast)
     *RetBitCast = Cast;
 
   // Replace all the original uses of the calling instruction with the bitcast.
   for (User *U : UsersToUpdate)
-    U->replaceUsesOfWith(CS.getInstruction(), Cast);
+    U->replaceUsesOfWith(&CB, Cast);
 }
 
 /// Predicate and clone the given call site.
@@ -255,26 +255,25 @@ static void createRetBitCast(CallSite CS, Type *RetTy, CastInst **RetBitCast) {
 ///     %t2 = phi i32 [ %t0, %else_bb ], [ %t1, %then_bb ]
 ///     br %normal_dst
 ///
-static Instruction *versionCallSite(CallSite CS, Value *Callee,
-                                    MDNode *BranchWeights) {
+static CallBase &versionCallSite(CallBase &CB, Value *Callee,
+                                 MDNode *BranchWeights) {
 
-  IRBuilder<> Builder(CS.getInstruction());
-  Instruction *OrigInst = CS.getInstruction();
+  IRBuilder<> Builder(&CB);
+  CallBase *OrigInst = &CB;
   BasicBlock *OrigBlock = OrigInst->getParent();
 
   // Create the compare. The called value and callee must have the same type to
   // be compared.
-  if (CS.getCalledValue()->getType() != Callee->getType())
-    Callee = Builder.CreateBitCast(Callee, CS.getCalledValue()->getType());
-  auto *Cond = Builder.CreateICmpEQ(CS.getCalledValue(), Callee);
+  if (CB.getCalledValue()->getType() != Callee->getType())
+    Callee = Builder.CreateBitCast(Callee, CB.getCalledValue()->getType());
+  auto *Cond = Builder.CreateICmpEQ(CB.getCalledValue(), Callee);
 
   // Create an if-then-else structure. The original instruction is moved into
   // the "else" block, and a clone of the original instruction is placed in the
   // "then" block.
   Instruction *ThenTerm = nullptr;
   Instruction *ElseTerm = nullptr;
-  SplitBlockAndInsertIfThenElse(Cond, CS.getInstruction(), &ThenTerm, &ElseTerm,
-                                BranchWeights);
+  SplitBlockAndInsertIfThenElse(Cond, &CB, &ThenTerm, &ElseTerm, BranchWeights);
   BasicBlock *ThenBlock = ThenTerm->getParent();
   BasicBlock *ElseBlock = ElseTerm->getParent();
   BasicBlock *MergeBlock = OrigInst->getParent();
@@ -283,7 +282,7 @@ static Instruction *versionCallSite(CallSite CS, Value *Callee,
   ElseBlock->setName("if.false.orig_indirect");
   MergeBlock->setName("if.end.icp");
 
-  Instruction *NewInst = OrigInst->clone();
+  CallBase *NewInst = cast<CallBase>(OrigInst->clone());
   OrigInst->moveBefore(ElseTerm);
   NewInst->insertBefore(ThenTerm);
 
@@ -315,18 +314,18 @@ static Instruction *versionCallSite(CallSite CS, Value *Callee,
   // Create a phi node for the returned value of the call site.
   createRetPHINode(OrigInst, NewInst, MergeBlock, Builder);
 
-  return NewInst;
+  return *NewInst;
 }
 
-bool llvm::isLegalToPromote(CallSite CS, Function *Callee,
+bool llvm::isLegalToPromote(CallBase &CB, Function *Callee,
                             const char **FailureReason) {
-  assert(!CS.getCalledFunction() && "Only indirect call sites can be promoted");
+  assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted");
 
   auto &DL = Callee->getParent()->getDataLayout();
 
   // Check the return type. The callee's return value type must be bitcast
   // compatible with the call site's type.
-  Type *CallRetTy = CS.getInstruction()->getType();
+  Type *CallRetTy = CB.getType();
   Type *FuncRetTy = Callee->getReturnType();
   if (CallRetTy != FuncRetTy)
     if (!CastInst::isBitOrNoopPointerCastable(FuncRetTy, CallRetTy, DL)) {
@@ -340,7 +339,7 @@ bool llvm::isLegalToPromote(CallSite CS, Function *Callee,
 
   // Check the number of arguments. The callee and call site must agree on the
   // number of arguments.
-  if (CS.arg_size() != NumParams && !Callee->isVarArg()) {
+  if (CB.arg_size() != NumParams && !Callee->isVarArg()) {
     if (FailureReason)
       *FailureReason = "The number of arguments mismatch";
     return false;
@@ -351,7 +350,7 @@ bool llvm::isLegalToPromote(CallSite CS, Function *Callee,
   // site.
   for (unsigned I = 0; I < NumParams; ++I) {
     Type *FormalTy = Callee->getFunctionType()->getFunctionParamType(I);
-    Type *ActualTy = CS.getArgument(I)->getType();
+    Type *ActualTy = CB.getArgOperand(I)->getType();
     if (FormalTy == ActualTy)
       continue;
     if (!CastInst::isBitOrNoopPointerCastable(ActualTy, FormalTy, DL)) {
@@ -364,31 +363,31 @@ bool llvm::isLegalToPromote(CallSite CS, Function *Callee,
   return true;
 }
 
-Instruction *llvm::promoteCall(CallSite CS, Function *Callee,
-                               CastInst **RetBitCast) {
-  assert(!CS.getCalledFunction() && "Only indirect call sites can be promoted");
+CallBase &llvm::promoteCall(CallBase &CB, Function *Callee,
+                            CastInst **RetBitCast) {
+  assert(!CB.getCalledFunction() && "Only indirect call sites can be promoted");
 
   // Set the called function of the call site to be the given callee (but don't
   // change the type).
-  cast<CallBase>(CS.getInstruction())->setCalledOperand(Callee);
+  CB.setCalledOperand(Callee);
 
   // Since the call site will no longer be direct, we must clear metadata that
   // is only appropriate for indirect calls. This includes !prof and !callees
   // metadata.
-  CS.getInstruction()->setMetadata(LLVMContext::MD_prof, nullptr);
-  CS.getInstruction()->setMetadata(LLVMContext::MD_callees, nullptr);
+  CB.setMetadata(LLVMContext::MD_prof, nullptr);
+  CB.setMetadata(LLVMContext::MD_callees, nullptr);
 
   // If the function type of the call site matches that of the callee, no
   // additional work is required.
-  if (CS.getFunctionType() == Callee->getFunctionType())
-    return CS.getInstruction();
+  if (CB.getFunctionType() == Callee->getFunctionType())
+    return CB;
 
   // Save the return types of the call site and callee.
-  Type *CallSiteRetTy = CS.getInstruction()->getType();
+  Type *CallSiteRetTy = CB.getType();
   Type *CalleeRetTy = Callee->getReturnType();
 
   // Change the function type of the call site the match that of the callee.
-  CS.mutateFunctionType(Callee->getFunctionType());
+  CB.mutateFunctionType(Callee->getFunctionType());
 
   // Inspect the arguments of the call site. If an argument's type doesn't
   // match the corresponding formal argument's type in the callee, bitcast it
@@ -397,19 +396,18 @@ Instruction *llvm::promoteCall(CallSite CS, Function *Callee,
   auto CalleeParamNum = CalleeType->getNumParams();
 
   LLVMContext &Ctx = Callee->getContext();
-  const AttributeList &CallerPAL = CS.getAttributes();
+  const AttributeList &CallerPAL = CB.getAttributes();
   // The new list of argument attributes.
   SmallVector<AttributeSet, 4> NewArgAttrs;
   bool AttributeChanged = false;
 
   for (unsigned ArgNo = 0; ArgNo < CalleeParamNum; ++ArgNo) {
-    auto *Arg = CS.getArgument(ArgNo);
+    auto *Arg = CB.getArgOperand(ArgNo);
     Type *FormalTy = CalleeType->getParamType(ArgNo);
     Type *ActualTy = Arg->getType();
     if (FormalTy != ActualTy) {
-      auto *Cast = CastInst::CreateBitOrPointerCast(Arg, FormalTy, "",
-                                                    CS.getInstruction());
-      CS.setArgument(ArgNo, Cast);
+      auto *Cast = CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", &CB);
+      CB.setArgOperand(ArgNo, Cast);
 
       // Remove any incompatible attributes for the argument.
       AttrBuilder ArgAttrs(CallerPAL.getParamAttributes(ArgNo));
@@ -434,37 +432,37 @@ Instruction *llvm::promoteCall(CallSite CS, Function *Callee,
   // Remove any incompatible return value attribute.
   AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex);
   if (!CallSiteRetTy->isVoidTy() && CallSiteRetTy != CalleeRetTy) {
-    createRetBitCast(CS, CallSiteRetTy, RetBitCast);
+    createRetBitCast(CB, CallSiteRetTy, RetBitCast);
     RAttrs.remove(AttributeFuncs::typeIncompatible(CalleeRetTy));
     AttributeChanged = true;
   }
 
   // Set the new callsite attribute.
   if (AttributeChanged)
-    CS.setAttributes(AttributeList::get(Ctx, CallerPAL.getFnAttributes(),
+    CB.setAttributes(AttributeList::get(Ctx, CallerPAL.getFnAttributes(),
                                         AttributeSet::get(Ctx, RAttrs),
                                         NewArgAttrs));
 
-  return CS.getInstruction();
+  return CB;
 }
 
-Instruction *llvm::promoteCallWithIfThenElse(CallSite CS, Function *Callee,
-                                             MDNode *BranchWeights) {
+CallBase &llvm::promoteCallWithIfThenElse(CallBase &CB, Function *Callee,
+                                          MDNode *BranchWeights) {
 
   // Version the indirect call site. If the called value is equal to the given
   // callee, 'NewInst' will be executed, otherwise the original call site will
   // be executed.
-  Instruction *NewInst = versionCallSite(CS, Callee, BranchWeights);
+  CallBase &NewInst = versionCallSite(CB, Callee, BranchWeights);
 
   // Promote 'NewInst' so that it directly calls the desired function.
-  return promoteCall(CallSite(NewInst), Callee);
+  return promoteCall(NewInst, Callee);
 }
 
-bool llvm::tryPromoteCall(CallSite &CS) {
-  assert(!CS.getCalledFunction());
-  Module *M = CS.getCaller()->getParent();
+bool llvm::tryPromoteCall(CallBase &CB) {
+  assert(!CB.getCalledFunction());
+  Module *M = CB.getCaller()->getParent();
   const DataLayout &DL = M->getDataLayout();
-  Value *Callee = CS.getCalledValue();
+  Value *Callee = CB.getCalledValue();
 
   LoadInst *VTableEntryLoad = dyn_cast<LoadInst>(Callee);
   if (!VTableEntryLoad)
@@ -511,11 +509,11 @@ bool llvm::tryPromoteCall(CallSite &CS) {
   if (!DirectCallee)
     return false; // No function pointer found.
 
-  if (!isLegalToPromote(CS, DirectCallee))
+  if (!isLegalToPromote(CB, DirectCallee))
     return false;
 
   // Success.
-  promoteCall(CS, DirectCallee);
+  promoteCall(CB, DirectCallee);
   return true;
 }
 

diff  --git a/llvm/lib/Transforms/Utils/InlineFunction.cpp b/llvm/lib/Transforms/Utils/InlineFunction.cpp
index 2c86af68d646..d7dd342004e2 100644
--- a/llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ b/llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -34,7 +34,6 @@
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DIBuilder.h"
@@ -2350,8 +2349,8 @@ llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI,
     for (BasicBlock &NewBB :
          make_range(FirstNewBlock->getIterator(), Caller->end()))
       for (Instruction &I : NewBB)
-        if (auto CS = CallSite(&I))
-          IFI.InlinedCallSites.push_back(CS);
+        if (auto *CB = dyn_cast<CallBase>(&I))
+          IFI.InlinedCallSites.push_back(CB);
   }
 
   // If we cloned in _exactly one_ basic block, and if that block ends in a

diff  --git a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
index 18db4623f368..eff8e27d36d6 100644
--- a/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CallPromotionUtilsTest.cpp
@@ -61,14 +61,13 @@ declare void @_ZN4Impl3RunEv(%class.Impl* %this)
   Inst = &*++F->front().rbegin();
   auto *CI = dyn_cast<CallInst>(Inst);
   ASSERT_TRUE(CI);
-  CallSite CS(CI);
-  ASSERT_FALSE(CS.getCalledFunction());
-  bool IsPromoted = tryPromoteCall(CS);
+  ASSERT_FALSE(CI->getCalledFunction());
+  bool IsPromoted = tryPromoteCall(*CI);
   EXPECT_TRUE(IsPromoted);
   GV = M->getNamedValue("_ZN4Impl3RunEv");
   ASSERT_TRUE(GV);
   auto *F1 = dyn_cast<Function>(GV);
-  EXPECT_EQ(F1, CS.getCalledFunction());
+  EXPECT_EQ(F1, CI->getCalledFunction());
 }
 
 TEST(CallPromotionUtilsTest, TryPromoteCall_NoFPLoad) {
@@ -92,9 +91,8 @@ define void @f(void (%class.Interface*)* %fp, %class.Interface* nonnull %base.i)
   Instruction *Inst = &F->front().front();
   auto *CI = dyn_cast<CallInst>(Inst);
   ASSERT_TRUE(CI);
-  CallSite CS(CI);
-  ASSERT_FALSE(CS.getCalledFunction());
-  bool IsPromoted = tryPromoteCall(CS);
+  ASSERT_FALSE(CI->getCalledFunction());
+  bool IsPromoted = tryPromoteCall(*CI);
   EXPECT_FALSE(IsPromoted);
 }
 
@@ -120,9 +118,8 @@ define void @f(void (%class.Interface*)** %vtable.i, %class.Interface* nonnull %
   Instruction *Inst = &*++F->front().rbegin();
   auto *CI = dyn_cast<CallInst>(Inst);
   ASSERT_TRUE(CI);
-  CallSite CS(CI);
-  ASSERT_FALSE(CS.getCalledFunction());
-  bool IsPromoted = tryPromoteCall(CS);
+  ASSERT_FALSE(CI->getCalledFunction());
+  bool IsPromoted = tryPromoteCall(*CI);
   EXPECT_FALSE(IsPromoted);
 }
 
@@ -156,9 +153,8 @@ declare void @_ZN4Impl3RunEv(%class.Impl* %this)
   Instruction *Inst = &*++F->front().rbegin();
   auto *CI = dyn_cast<CallInst>(Inst);
   ASSERT_TRUE(CI);
-  CallSite CS(CI);
-  ASSERT_FALSE(CS.getCalledFunction());
-  bool IsPromoted = tryPromoteCall(CS);
+  ASSERT_FALSE(CI->getCalledFunction());
+  bool IsPromoted = tryPromoteCall(*CI);
   EXPECT_FALSE(IsPromoted);
 }
 
@@ -199,9 +195,8 @@ declare void @_ZN4Impl3RunEv(%class.Impl* %this)
   Inst = &*++F->front().rbegin();
   auto *CI = dyn_cast<CallInst>(Inst);
   ASSERT_TRUE(CI);
-  CallSite CS(CI);
-  ASSERT_FALSE(CS.getCalledFunction());
-  bool IsPromoted = tryPromoteCall(CS);
+  ASSERT_FALSE(CI->getCalledFunction());
+  bool IsPromoted = tryPromoteCall(*CI);
   EXPECT_FALSE(IsPromoted);
 }
 
@@ -242,9 +237,8 @@ declare void @_ZN4Impl3RunEv(%class.Impl* %this)
   Inst = &*++F->front().rbegin();
   auto *CI = dyn_cast<CallInst>(Inst);
   ASSERT_TRUE(CI);
-  CallSite CS(CI);
-  ASSERT_FALSE(CS.getCalledFunction());
-  bool IsPromoted = tryPromoteCall(CS);
+  ASSERT_FALSE(CI->getCalledFunction());
+  bool IsPromoted = tryPromoteCall(*CI);
   EXPECT_FALSE(IsPromoted);
 }
 
@@ -302,14 +296,13 @@ declare i32 @_ZN1A3vf2Ev(%struct.A* %this)
   Inst = &*++F->front().rbegin();
   auto *CI = dyn_cast<CallInst>(Inst);
   ASSERT_TRUE(CI);
-  CallSite CS1(CI);
-  ASSERT_FALSE(CS1.getCalledFunction());
-  bool IsPromoted1 = tryPromoteCall(CS1);
+  ASSERT_FALSE(CI->getCalledFunction());
+  bool IsPromoted1 = tryPromoteCall(*CI);
   EXPECT_TRUE(IsPromoted1);
   GV = M->getNamedValue("_ZN1A3vf1Ev");
   ASSERT_TRUE(GV);
   F = dyn_cast<Function>(GV);
-  EXPECT_EQ(F, CS1.getCalledFunction());
+  EXPECT_EQ(F, CI->getCalledFunction());
 
   GV = M->getNamedValue("_Z2g2v");
   ASSERT_TRUE(GV);
@@ -321,14 +314,13 @@ declare i32 @_ZN1A3vf2Ev(%struct.A* %this)
   Inst = &*++F->front().rbegin();
   CI = dyn_cast<CallInst>(Inst);
   ASSERT_TRUE(CI);
-  CallSite CS2(CI);
-  ASSERT_FALSE(CS2.getCalledFunction());
-  bool IsPromoted2 = tryPromoteCall(CS2);
+  ASSERT_FALSE(CI->getCalledFunction());
+  bool IsPromoted2 = tryPromoteCall(*CI);
   EXPECT_TRUE(IsPromoted2);
   GV = M->getNamedValue("_ZN1A3vf2Ev");
   ASSERT_TRUE(GV);
   F = dyn_cast<Function>(GV);
-  EXPECT_EQ(F, CS2.getCalledFunction());
+  EXPECT_EQ(F, CI->getCalledFunction());
 }
 
 // Check that it isn't crashing due to missing promotion legality.
@@ -372,8 +364,7 @@ declare %struct2 @_ZN4Impl3RunEv(%class.Impl* %this)
   Inst = &*++F->front().rbegin();
   auto *CI = dyn_cast<CallInst>(Inst);
   ASSERT_TRUE(CI);
-  CallSite CS(CI);
-  ASSERT_FALSE(CS.getCalledFunction());
-  bool IsPromoted = tryPromoteCall(CS);
+  ASSERT_FALSE(CI->getCalledFunction());
+  bool IsPromoted = tryPromoteCall(*CI);
   EXPECT_FALSE(IsPromoted);
 }


        


More information about the llvm-commits mailing list