[llvm] cea6f4d - [llvm][NFC][CallSite] Remove CallSite from TypeMetadataUtils & related
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 23 08:23:36 PDT 2020
Author: Mircea Trofin
Date: 2020-04-23T08:23:16-07:00
New Revision: cea6f4d5f8431ea723e85f6e57812feb3633ecbe
URL: https://github.com/llvm/llvm-project/commit/cea6f4d5f8431ea723e85f6e57812feb3633ecbe
DIFF: https://github.com/llvm/llvm-project/commit/cea6f4d5f8431ea723e85f6e57812feb3633ecbe.diff
LOG: [llvm][NFC][CallSite] Remove CallSite from TypeMetadataUtils & related
Reviewers: craig.topper, dblaikie
Subscribers: hiraditya, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D78666
Added:
Modified:
llvm/include/llvm/Analysis/TypeMetadataUtils.h
llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
llvm/lib/Analysis/TypeMetadataUtils.cpp
llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/TypeMetadataUtils.h b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
index 43ce26147c2e..ea59c0cbc871 100644
--- a/llvm/include/llvm/Analysis/TypeMetadataUtils.h
+++ b/llvm/include/llvm/Analysis/TypeMetadataUtils.h
@@ -15,7 +15,7 @@
#define LLVM_ANALYSIS_TYPEMETADATAUTILS_H
#include "llvm/ADT/SmallVector.h"
-#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Instructions.h"
namespace llvm {
@@ -33,7 +33,7 @@ struct DevirtCallSite {
/// The offset from the address point to the virtual function.
uint64_t Offset;
/// The call site itself.
- CallSite CS;
+ CallBase &CB;
};
/// Given a call to the intrinsic \@llvm.type.test, find all devirtualizable
diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
index 1ff47e10bd99..b78115bd43f7 100644
--- a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
+++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
@@ -28,7 +28,6 @@
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
@@ -99,7 +98,7 @@ static bool findRefEdges(ModuleSummaryIndex &Index, const User *CurUser,
if (!Visited.insert(U).second)
continue;
- ImmutableCallSite CS(U);
+ const auto *CB = dyn_cast<CallBase>(U);
for (const auto &OI : U->operands()) {
const User *Operand = dyn_cast<User>(OI);
@@ -113,7 +112,7 @@ static bool findRefEdges(ModuleSummaryIndex &Index, const User *CurUser,
// We have a reference to a global value. This should be added to
// the reference set unless it is a callee. Callees are handled
// specially by WriteFunction and are added to a separate list.
- if (!(CS && CS.isCallee(&OI)))
+ if (!(CB && CB->isCallee(&OI)))
RefEdges.insert(Index.getOrInsertValueInfo(GV));
continue;
}
@@ -145,7 +144,7 @@ static void addVCallToSet(DevirtCallSite Call, GlobalValue::GUID Guid,
SetVector<FunctionSummary::ConstVCall> &ConstVCalls) {
std::vector<uint64_t> Args;
// Start from the second argument to skip the "this" pointer.
- for (auto &Arg : make_range(Call.CS.arg_begin() + 1, Call.CS.arg_end())) {
+ for (auto &Arg : make_range(Call.CB.arg_begin() + 1, Call.CB.arg_end())) {
auto *CI = dyn_cast<ConstantInt>(Arg);
if (!CI || CI->getBitWidth() > 64) {
VCalls.insert({Guid, Call.Offset});
@@ -304,8 +303,8 @@ static void computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M,
}
}
findRefEdges(Index, &I, RefEdges, Visited);
- auto CS = ImmutableCallSite(&I);
- if (!CS)
+ const auto *CB = dyn_cast<CallBase>(&I);
+ if (!CB)
continue;
const auto *CI = dyn_cast<CallInst>(&I);
@@ -317,8 +316,8 @@ static void computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M,
if (HasLocalsInUsedOrAsm && CI && CI->isInlineAsm())
HasInlineAsmMaybeReferencingInternal = true;
- auto *CalledValue = CS.getCalledValue();
- auto *CalledFunction = CS.getCalledFunction();
+ auto *CalledValue = CB->getCalledValue();
+ auto *CalledFunction = CB->getCalledFunction();
if (CalledValue && !CalledFunction) {
CalledValue = CalledValue->stripPointerCasts();
// Stripping pointer casts can reveal a called function.
diff --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp
index 072d291f3f93..f187d0796771 100644
--- a/llvm/lib/Analysis/TypeMetadataUtils.cpp
+++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp
@@ -37,10 +37,10 @@ findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
if (isa<BitCastInst>(User)) {
findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI,
DT);
- } else if (auto CI = dyn_cast<CallInst>(User)) {
- DevirtCalls.push_back({Offset, CI});
- } else if (auto II = dyn_cast<InvokeInst>(User)) {
- DevirtCalls.push_back({Offset, II});
+ } else if (auto *CI = dyn_cast<CallInst>(User)) {
+ DevirtCalls.push_back({Offset, *CI});
+ } else if (auto *II = dyn_cast<InvokeInst>(User)) {
+ DevirtCalls.push_back({Offset, *II});
} else if (HasNonCallUses) {
*HasNonCallUses = true;
}
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 60f8e935ffdd..2e72b2981d83 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -64,7 +64,6 @@
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
-#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
@@ -354,20 +353,20 @@ namespace {
// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
- Value *VTable;
- CallSite CS;
+ Value *VTable = nullptr;
+ CallBase &CB;
// If non-null, this field points to the associated unsafe use count stored in
// the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
// of that field for details.
- unsigned *NumUnsafeUses;
+ unsigned *NumUnsafeUses = nullptr;
void
emitRemark(const StringRef OptName, const StringRef TargetName,
function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
- Function *F = CS.getCaller();
- DebugLoc DLoc = CS->getDebugLoc();
- BasicBlock *Block = CS.getParent();
+ Function *F = CB.getCaller();
+ DebugLoc DLoc = CB.getDebugLoc();
+ BasicBlock *Block = CB.getParent();
using namespace ore;
OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
@@ -382,12 +381,12 @@ struct VirtualCallSite {
Value *New) {
if (RemarksEnabled)
emitRemark(OptName, TargetName, OREGetter);
- CS->replaceAllUsesWith(New);
- if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
- BranchInst::Create(II->getNormalDest(), CS.getInstruction());
+ CB.replaceAllUsesWith(New);
+ if (auto *II = dyn_cast<InvokeInst>(&CB)) {
+ BranchInst::Create(II->getNormalDest(), &CB);
II->getUnwindDest()->removePredecessor(II->getParent());
}
- CS->eraseFromParent();
+ CB.eraseFromParent();
// This use is no longer unsafe.
if (NumUnsafeUses)
--*NumUnsafeUses;
@@ -460,18 +459,18 @@ struct VTableSlotInfo {
// "this"), grouped by argument list.
std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;
- void addCallSite(Value *VTable, CallSite CS, unsigned *NumUnsafeUses);
+ void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);
private:
- CallSiteInfo &findCallSiteInfo(CallSite CS);
+ CallSiteInfo &findCallSiteInfo(CallBase &CB);
};
-CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
+CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
std::vector<uint64_t> Args;
- auto *CI = dyn_cast<IntegerType>(CS.getType());
- if (!CI || CI->getBitWidth() > 64 || CS.arg_empty())
+ auto *CBType = dyn_cast<IntegerType>(CB.getType());
+ if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
return CSInfo;
- for (auto &&Arg : make_range(CS.arg_begin() + 1, CS.arg_end())) {
+ for (auto &&Arg : make_range(CB.arg_begin() + 1, CB.arg_end())) {
auto *CI = dyn_cast<ConstantInt>(Arg);
if (!CI || CI->getBitWidth() > 64)
return CSInfo;
@@ -480,11 +479,11 @@ CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallSite CS) {
return ConstCSInfo[Args];
}
-void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
+void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
unsigned *NumUnsafeUses) {
- auto &CSI = findCallSiteInfo(CS);
+ auto &CSI = findCallSiteInfo(CB);
CSI.AllCallSitesDevirted = false;
- CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
+ CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
}
struct DevirtModule {
@@ -1029,8 +1028,8 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
if (RemarksEnabled)
VCallSite.emitRemark("single-impl",
TheFn->stripPointerCasts()->getName(), OREGetter);
- VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
- TheFn, VCallSite.CS.getCalledValue()->getType()));
+ VCallSite.CB.setCalledOperand(ConstantExpr::getBitCast(
+ TheFn, VCallSite.CB.getCalledValue()->getType()));
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
--*VCallSite.NumUnsafeUses;
@@ -1253,10 +1252,10 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
if (CSInfo.AllCallSitesDevirted)
return;
for (auto &&VCallSite : CSInfo.CallSites) {
- CallSite CS = VCallSite.CS;
+ CallBase &CB = VCallSite.CB;
// Jump tables are only profitable if the retpoline mitigation is enabled.
- Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
+ Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
if (FSAttr.hasAttribute(Attribute::None) ||
!FSAttr.getValueAsString().contains("+retpoline"))
continue;
@@ -1269,42 +1268,40 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
// x86_64.
std::vector<Type *> NewArgs;
NewArgs.push_back(Int8PtrTy);
- for (Type *T : CS.getFunctionType()->params())
+ for (Type *T : CB.getFunctionType()->params())
NewArgs.push_back(T);
FunctionType *NewFT =
- FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
- CS.getFunctionType()->isVarArg());
+ FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
+ CB.getFunctionType()->isVarArg());
PointerType *NewFTPtr = PointerType::getUnqual(NewFT);
- IRBuilder<> IRB(CS.getInstruction());
+ IRBuilder<> IRB(&CB);
std::vector<Value *> Args;
Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
- for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
- Args.push_back(CS.getArgOperand(I));
+ Args.insert(Args.end(), CB.arg_begin(), CB.arg_end());
- CallSite NewCS;
- if (CS.isCall())
+ CallBase *NewCS = nullptr;
+ if (isa<CallInst>(CB))
NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
else
- NewCS = IRB.CreateInvoke(
- NewFT, IRB.CreateBitCast(JT, NewFTPtr),
- cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
- cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
- NewCS.setCallingConv(CS.getCallingConv());
+ NewCS = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr),
+ cast<InvokeInst>(CB).getNormalDest(),
+ cast<InvokeInst>(CB).getUnwindDest(), Args);
+ NewCS->setCallingConv(CB.getCallingConv());
- AttributeList Attrs = CS.getAttributes();
+ AttributeList Attrs = CB.getAttributes();
std::vector<AttributeSet> NewArgAttrs;
NewArgAttrs.push_back(AttributeSet::get(
M.getContext(), ArrayRef<Attribute>{Attribute::get(
M.getContext(), Attribute::Nest)}));
for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
NewArgAttrs.push_back(Attrs.getParamAttributes(I));
- NewCS.setAttributes(
+ NewCS->setAttributes(
AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
Attrs.getRetAttributes(), NewArgAttrs));
- CS->replaceAllUsesWith(NewCS.getInstruction());
- CS->eraseFromParent();
+ CB.replaceAllUsesWith(NewCS);
+ CB.eraseFromParent();
// This use is no longer unsafe.
if (VCallSite.NumUnsafeUses)
@@ -1355,7 +1352,7 @@ void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
for (auto Call : CSInfo.CallSites)
Call.replaceAndErase(
"uniform-ret-val", FnName, RemarksEnabled, OREGetter,
- ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal));
+ ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
CSInfo.markDevirt();
}
@@ -1461,11 +1458,11 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
bool IsOne,
Constant *UniqueMemberAddr) {
for (auto &&Call : CSInfo.CallSites) {
- IRBuilder<> B(Call.CS.getInstruction());
+ IRBuilder<> B(&Call.CB);
Value *Cmp =
B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
- Cmp = B.CreateZExt(Cmp, Call.CS->getType());
+ Cmp = B.CreateZExt(Cmp, Call.CB.getType());
Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
Cmp);
}
@@ -1529,8 +1526,8 @@ bool DevirtModule::tryUniqueRetValOpt(
void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
Constant *Byte, Constant *Bit) {
for (auto Call : CSInfo.CallSites) {
- auto *RetType = cast<IntegerType>(Call.CS.getType());
- IRBuilder<> B(Call.CS.getInstruction());
+ auto *RetType = cast<IntegerType>(Call.CB.getType());
+ IRBuilder<> B(&Call.CB);
Value *Addr =
B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
if (RetType->getBitWidth() == 1) {
@@ -1716,7 +1713,7 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc) {
// points to a member of the type identifier %md. Group calls by (type ID,
// offset) pair (effectively the identity of the virtual function) and store
// to CallSlots.
- DenseSet<CallSite> SeenCallSites;
+ DenseSet<CallBase *> SeenCallSites;
for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
I != E;) {
auto CI = dyn_cast<CallInst>(I->getUser());
@@ -1741,8 +1738,8 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc) {
// and we don't want to process call sites multiple times. We can't
// just skip the vtable Ptr if it has been seen before, however, since
// it may be shared by type tests that dominate
different calls.
- if (SeenCallSites.insert(Call.CS).second)
- CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr);
+ if (SeenCallSites.insert(&Call.CB).second)
+ CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
}
}
@@ -1828,7 +1825,7 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
if (HasNonCallUses)
++NumUnsafeUses;
for (DevirtCallSite Call : DevirtCalls) {
- CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS,
+ CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
&NumUnsafeUses);
}
More information about the llvm-commits
mailing list