[llvm-commits] [poolalloc] r110578 - in /poolalloc/trunk: include/assistDS/Devirt.h lib/AssistDS/Devirt.cpp
John Criswell
criswell at uiuc.edu
Mon Aug 9 10:10:38 PDT 2010
Author: criswell
Date: Mon Aug 9 12:10:38 2010
New Revision: 110578
URL: http://llvm.org/viewvc/llvm-project?rev=110578&view=rev
Log:
Refactoring of Andrew Lenharth's devirtualization pass.
This version works on user-space code and currently devirtualizes all indirect
function calls for which all function targets are known.
Added:
poolalloc/trunk/include/assistDS/Devirt.h
poolalloc/trunk/lib/AssistDS/Devirt.cpp
- copied, changed from r110449, poolalloc/trunk/lib/AssistDS/SVADevirt.cpp
Added: poolalloc/trunk/include/assistDS/Devirt.h
URL: http://llvm.org/viewvc/llvm-project/poolalloc/trunk/include/assistDS/Devirt.h?rev=110578&view=auto
==============================================================================
--- poolalloc/trunk/include/assistDS/Devirt.h (added)
+++ poolalloc/trunk/include/assistDS/Devirt.h Mon Aug 9 12:10:38 2010
@@ -0,0 +1,72 @@
+//===- Devirt.cpp - Devirtualize using the sig match intrinsic in llva ----===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines an LLVM transform that converts indirect function calls
+// into direct function calls.
+//
+//===----------------------------------------------------------------------===//
+
+#include "dsa/CallTargets.h"
+
+#include "llvm/Constants.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Pass.h"
+#include "llvm/Module.h"
+#include "llvm/Function.h"
+#include "llvm/Instructions.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Support/InstVisitor.h"
+
+using namespace llvm;
+
+namespace llvm {
+ //
+ // Class: Devirtualize
+ //
+ // Description:
+ // This transform pass will look for indirect function calls and transform
+ // them into a switch statement that selects one of several direct function
+ // calls to execute.
+ //
+ class Devirtualize : public ModulePass, public InstVisitor<Devirtualize> {
+ private:
+ // Access to analysis pass which finds targets of indirect function calls
+ CallTargetFinder* CTF;
+
+ // Worklist of call sites to transform
+ std::vector<Instruction *> Worklist;
+
+ protected:
+ void makeDirectCall (CallSite & CS);
+ Function* buildBounce (CallSite cs,std::vector<const Function*>& Targets);
+
+ public:
+ static char ID;
+ Devirtualize() : ModulePass(&ID), CTF(0) {}
+
+ virtual bool runOnModule(Module & M);
+
+ virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<CallTargetFinder>();
+ }
+
+ // Visitor methods for analyzing instructions
+ //void visitInstruction(Instruction &I);
+ void visitCallSite(CallSite &CS);
+ void visitCallInst(CallInst &CI) {
+ CallSite CS(&CI);
+ visitCallSite(CS);
+ }
+ void visitInvokeInst(InvokeInst &II) {
+ CallSite CS(&II);
+ visitCallSite(CS);
+ }
+ };
+}
+
Copied: poolalloc/trunk/lib/AssistDS/Devirt.cpp (from r110449, poolalloc/trunk/lib/AssistDS/SVADevirt.cpp)
URL: http://llvm.org/viewvc/llvm-project/poolalloc/trunk/lib/AssistDS/Devirt.cpp?p2=poolalloc/trunk/lib/AssistDS/Devirt.cpp&p1=poolalloc/trunk/lib/AssistDS/SVADevirt.cpp&r1=110449&r2=110578&rev=110578&view=diff
==============================================================================
--- poolalloc/trunk/lib/AssistDS/SVADevirt.cpp (original)
+++ poolalloc/trunk/lib/AssistDS/Devirt.cpp Mon Aug 9 12:10:38 2010
@@ -6,15 +6,11 @@
// the University of Illinois Open Source License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
+
#define DEBUG_TYPE "devirt"
-#include "llvm/Constants.h"
-#include "llvm/Transforms/IPO.h"
-#include "dsa/CallTargets.h"
-#include "llvm/Pass.h"
-#include "llvm/Module.h"
-#include "llvm/Function.h"
-#include "llvm/Instructions.h"
-#include "llvm/DerivedTypes.h"
+
+#include "assistDS/Devirt.h"
+
#include "llvm/Support/CommandLine.h"
#include "llvm/ADT/Statistic.h"
@@ -24,15 +20,23 @@
using namespace llvm;
-#if 0
+// Pass statistics
+STATISTIC(FuncAdded, "Number of bounce functions added");
+STATISTIC(CSConvert, "Number of call sites converted");
+
+// Pass registration
+RegisterPass<Devirtualize>
+X ("devirt", "Devirtualize indirect function calls");
//
// Function: castTo()
//
-// Description: // Given an LLVM value, insert a cast instruction to make it a given type.
+// Description:
+// Given an LLVM value, insert a cast instruction to make it a given type.
//
static inline Value *
-castTo (Value * V, const Type * Ty, Instruction * InsertPt) { //
+castTo (Value * V, const Type * Ty, std::string Name, Instruction * InsertPt) {
+ //
// Don't bother creating a cast if it's already the correct type.
//
if (V->getType() == Ty)
@@ -42,215 +46,259 @@
// If it's a constant, just create a constant expression.
//
if (Constant * C = dyn_cast<Constant>(V)) {
- Constant * CE = ConstantExpr::getCast (C, Ty);
+ Constant * CE = ConstantExpr::getZExtOrBitCast (C, Ty);
return CE;
}
//
// Otherwise, insert a cast instruction.
//
- return new CastInst::Create(V, Ty, "cast", InsertPt);
+ return CastInst::CreateZExtOrBitCast (V, Ty, Name, InsertPt);
}
-namespace {
+//
+// Method: buildBounce()
+//
+// Description:
+// Replaces the given call site with a call to a bounce function. The
+// bounce function compares the function pointer to one of the given
+// target functions and calls the function directly if the pointer
+// matches.
+//
+Function*
+Devirtualize::buildBounce (CallSite CS, std::vector<const Function*>& Targets) {
+ //
+ // Update the statistics on the number of bounce functions added to the
+ // module.
+ //
+ ++FuncAdded;
- static cl::opt<int>
- VirtualLimit("devirt-limit", cl::Hidden, cl::init(16),
- cl::desc("Maximum number of callees to devirtualize at a call site"));
- STATISTIC(FuncAdded, "Number of bounce functions added");
- STATISTIC(CSConvert, "Number of call sites converted");
-
-
- class Devirtualize : public ModulePass {
-
-
- Function * IndirectFuncFail;
-
- std::map<std::pair<const Type*, std::vector<Function*> >, Function*> cache;
- int fnum;
-
- //
- // Method: buildBounce()
- //
- // Description:
- // Replaces the given call site with a call to a bounce function. The
- // bounce function compares the function pointer to one of the given
- // target functions and calls the function directly if the pointer
- // matches.
- Function* buildBounce (CallSite cs,
- std::vector<Function*>& Targets,
- Module& M) {
-
- Value* ptr = cs.getCalledValue();
- const FunctionType* OrigType =
- cast<FunctionType>(cast<PointerType>(ptr->getType())->getElementType());;
- ++FuncAdded;
-
- std::vector< const Type *> TP(OrigType->param_begin(), OrigType->param_end());
- TP.insert(TP.begin(), ptr->getType());
- const FunctionType* NewTy = FunctionType::get(OrigType->getReturnType(), TP, false);
- Function* F = new Function(NewTy, GlobalValue::InternalLinkage, "devirtbounce", &M);
- std::map<Function*, BasicBlock*> targets;
-
- F->arg_begin()->setName("funcPtr");
- std::vector<Value*> fargs;
- for(Function::arg_iterator ai = F->arg_begin(), ae = F->arg_end(); ai != ae; ++ai)
- if (ai != F->arg_begin()) {
- fargs.push_back(ai);
- ai->setName("arg");
- }
-
- for (std::vector<Function*>::iterator i = Targets.begin(), e = Targets.end();
- i != e; ++i) {
- Function* FL = *i;
- BasicBlock* BL = new BasicBlock(FL->getName(), F);
- targets[FL] = BL;
-
- //Make call
- Value* call = new CallInst(FL, fargs, "", BL);
-
- //return correctly
- if (OrigType->getReturnType() == Type::VoidTy)
- new ReturnInst(0, BL);
- else
- new ReturnInst(call, BL);
- }
-
- // Create a set of tests that search for the correct function target
- // and call it directly. If none of the target functions match,
- // call pchk_ind_fail() to note the failure.
-
- //
- // Create the failure basic block. Then, add the following:
- // o the terminating instruction
- // o the indirect call to the original function
- // o a call to phck_ind_fail()
- //
- BasicBlock* tail = new BasicBlock("fail", F, &F->getEntryBlock());
- Instruction * InsertPt;
-#if 0
- InsertPt = new UnreachableInst(tail);
-#else
- Value* p = F->arg_begin();
- Instruction * realCall = new CallInst (p, fargs, "", tail);
- if (OrigType->getReturnType() == Type::VoidTy)
- InsertPt = new ReturnInst(0, tail);
- else
- InsertPt = new ReturnInst(realCall, tail);
-#endif
- Value * FuncVoidPtr = castTo (p,
- PointerType::get(Type::SByteTy),
- realCall);
- new CallInst (IndirectFuncFail, FuncVoidPtr, "", realCall);
-
-
- // Create basic blocks for valid target functions
- for (std::vector<Function*>::iterator i = Targets.begin(), e = Targets.end();
- i != e; ++i) {
- BasicBlock* TB = targets[*i];
- BasicBlock* newB = new BasicBlock("test." + (*i)->getName(), F, &F->getEntryBlock());
- SetCondInst* setcc = new SetCondInst(Instruction::SetEQ, *i, p, "sc", newB);
- new BranchInst(TB, tail, setcc, newB);
- tail = newB;
- }
- return F;
- }
+ //
+ // Create a bounce function that has a function signature almost identical
+ // to the function being called. The only difference is that it will have
+ // an additional pointer argument at the beginning of its argument list that
+ // will be the function to call.
+ //
+ Value* ptr = CS.getCalledValue();
+ const FunctionType* OrigType =
+ cast<FunctionType>(cast<PointerType>(ptr->getType())->getElementType());
+
+ std::vector<const Type *> TP (OrigType->param_begin(), OrigType->param_end());
+ TP.insert (TP.begin(), ptr->getType());
+ const FunctionType* NewTy = FunctionType::get(OrigType->getReturnType(), TP, false);
+ Module * M = CS.getInstruction()->getParent()->getParent()->getParent();
+ Function* F = Function::Create (NewTy,
+ GlobalValue::InternalLinkage,
+ "devirtbounce",
+ M);
- public:
- static char ID;
- Devirtualize() : ModulePass(&ID) {}
-
- virtual bool runOnModule(Module &M) {
- CallTargetFinder* CTF = &getAnalysis<CallTargetFinder>();
-
- // Get references to functions that are needed in the module
- Function* ams = M.getNamedFunction ("llva_assert_match_sig");
- if (!ams)
- return false;
-
- IndirectFuncFail = M.getOrInsertFunction ("pchk_ind_fail",
- Type::VoidTy,
- PointerType::getUnqual(Type::Int8Ty),
- NULL);
-
- std::set<Value*> safecalls;
- std::vector<Instruction*> toDelete;
-
- for (Value::use_iterator ii = ams->use_begin(), ee = ams->use_end();
- ii != ee; ++ii) {
- if (CallInst* CI = dyn_cast<CallInst>(*ii)) {
- std::cerr << "Found safe call site in "
- << CI->getParent()->getParent()->getName() << "\n";
- Value* V = CI->getOperand(1);
- toDelete.push_back(CI);
- do {
- //V->dump();
- safecalls.insert(V);
- if (CastInst* CV = dyn_cast<CastInst>(V))
- V = CV->getOperand(0);
- else V = 0;
- } while (V);
- }
- }
-
- for(std::set<Value*>::iterator i = safecalls.begin(), e = safecalls.end();
- i != e; ++i) {
- for (Value::use_iterator uii = (*i)->use_begin(), uie = (*i)->use_end();
- uii != uie; ++uii) {
- CallSite cs = CallSite::get(*uii);
- bool isSafeCall = cs.getInstruction() &&
- safecalls.find(cs.getCalledValue()) != safecalls.end();
- if (cs.getInstruction() && !cs.getCalledFunction() &&
- (isSafeCall || CTF->isComplete(cs))) {
- std::vector<const Function*> Targets;
- for (std::vector<const Function*>::iterator ii = CTF->begin(cs), ee = CTF->end(cs);
- ii != ee; ++ii)
- if (!isSafeCall || (*ii)->getType() == cs.getCalledValue()->getType())
- Targets.push_back(*ii);
-
- if (Targets.size() > 0) {
- std::cerr << "Target count: " << Targets.size() << " in " << cs.getInstruction()->getParent()->getParent()->getName() << "\n";
- Function* NF = buildBounce(cs, Targets, M);
- if (CallInst* ci = dyn_cast<CallInst>(cs.getInstruction())) {
- ++CSConvert;
- std::vector<Value*> Par(ci->op_begin(), ci->op_end());
- CallInst* cn = CallInst::Create(NF, Par.begin(), Par.end(),
- ci->getName() + ".dv", ci);
- ci->replaceAllUsesWith(cn);
- toDelete.push_back(ci);
- } else if (InvokeInst* ci = dyn_cast<InvokeInst>(cs.getInstruction())) {
- ++CSConvert;
- std::vector<Value*> Par(ci->op_begin(), ci->op_end());
- InvokeInst* cn = InvokeInst::Create(NF, ci->getNormalDest(),
- ci->getUnwindDest(),
- Par, ci->getName()+".dv",
- ci);
- ci->replaceAllUsesWith(cn);
- toDelete.push_back(ci);
- }
- } else //Target size == 0
- std::cerr << "Call site found, but no Targets\n";
- }
- }
- }
-
- bool changed = false;
- for (std::vector<Instruction*>::iterator ii = toDelete.begin(), ee = toDelete.end();
- ii != ee; ++ii) {
- changed = true;
- (*ii)->eraseFromParent();
- }
- return changed;
+ //
+ // Set the names of the arguments. Also, record the arguments in a vector
+ // for subsequent access.
+ //
+ F->arg_begin()->setName("funcPtr");
+ std::vector<Value*> fargs;
+ for(Function::arg_iterator ai = F->arg_begin(), ae = F->arg_end(); ai != ae; ++ai)
+ if (ai != F->arg_begin()) {
+ fargs.push_back(ai);
+ ai->setName("arg");
}
- virtual void getAnalysisUsage(AnalysisUsage &AU) const {
- AU.addRequired<CallTargetFinder>();
+ //
+ // For each function target, create a basic block that will call that
+ // function directly.
+ //
+ std::map<const Function*, BasicBlock*> targets;
+ for (unsigned index = 0; index < Targets.size(); ++index) {
+ const Function* FL = Targets[index];
+
+ // Create the basic block for doing the direct call
+ BasicBlock* BL = BasicBlock::Create (M->getContext(), FL->getName(), F);
+ targets[FL] = BL;
+
+ // Create the direct function call
+ Value* directCall = CallInst::Create ((Value *)FL,
+ fargs.begin(),
+ fargs.end(),
+ "",
+ BL);
+
+ // Add the return instruction for the basic block
+ if (OrigType->getReturnType()->isVoidTy())
+ ReturnInst::Create (M->getContext(), BL);
+ else
+ ReturnInst::Create (M->getContext(), directCall, BL);
+ }
+
+ //
+ // Create a set of tests that search for the correct function target
+ // and call it directly. If none of the target functions match,
+ // abort (or make the result unreachable).
+ //
+
+ //
+ // Create the failure basic block. This basic block should simply be an
+ // unreachable instruction.
+ //
+ BasicBlock* tail = BasicBlock::Create (M->getContext(),
+ "fail",
+ F,
+ &F->getEntryBlock());
+ Instruction * InsertPt;
+ InsertPt = new UnreachableInst (M->getContext(), tail);
+
+ //
+ // Create basic blocks for valid target functions.
+ //
+ for (unsigned index = 0; index < Targets.size(); ++index) {
+ const Function * Target = Targets[index];
+ BasicBlock* TB = targets[Target];
+ BasicBlock* newB = BasicBlock::Create (M->getContext(),
+ "test." + Target->getName(),
+ F,
+ &F->getEntryBlock());
+ CmpInst * setcc = CmpInst::Create (Instruction::ICmp,
+ CmpInst::ICMP_EQ,
+ (Value *) Target,
+ &(*(F->arg_begin())),
+ "sc",
+ newB);
+ BranchInst::Create (TB, tail, setcc, newB);
+ tail = newB;
+ }
+ return F;
+}
+
+//
+// Method: makeDirectCall()
+//
+// Description:
+// Transform the specified call site into a direct call.
+//
+// Inputs:
+// CS - The call site to transform.
+//
+// Preconditions:
+// 1) This method assumes that CS is an indirect call site.
+// 2) This method assumes that a pointer to the CallTarget analysis pass has
+// already been acquired by the class.
+//
+void
+Devirtualize::makeDirectCall (CallSite & CS) {
+ //
+ // Find the targets of the indirect function call.
+ //
+ std::vector<const Function*> Targets;
+ Targets.insert (Targets.begin(), CTF->begin(CS), CTF->end(CS));
+
+ //
+ // Convert the call site if there were any function call targets found.
+ //
+ if (Targets.size() > 0) {
+ //
+ // Build a function which will implement a switch statement. The switch
+ // statement will determine which function target to call and call it.
+ //
+ Function* NF = buildBounce (CS, Targets);
+
+ //
+ // Replace the original call with a call to the bounce function.
+ //
+ if (CallInst* CI = dyn_cast<CallInst>(CS.getInstruction())) {
+ std::vector<Value*> Params (CI->op_begin(), CI->op_end());
+ CallInst* CN = CallInst::Create (NF,
+ Params.begin(),
+ Params.end(),
+ CI->getName() + ".dv",
+ CI);
+ CI->replaceAllUsesWith(CN);
+ CI->eraseFromParent();
+ } else if (InvokeInst* CI = dyn_cast<InvokeInst>(CS.getInstruction())) {
+ std::vector<Value*> Params (CI->op_begin(), CI->op_end());
+ InvokeInst* CN = InvokeInst::Create(NF,
+ CI->getNormalDest(),
+ CI->getUnwindDest(),
+ Params.begin(),
+ Params.end(),
+ CI->getName()+".dv",
+ CI);
+ CI->replaceAllUsesWith(CN);
+ CI->eraseFromParent();
}
- };
+ //
+ // Update the statistics on the number of transformed call sites.
+ //
+ ++CSConvert;
+ }
+
+ return;
+}
+
+//
+// Method: visitCallSite()
+//
+// Description:
+// Examine the specified call site. If it is an indirect call, mark it for
+// transformation into a direct call.
+//
+void
+Devirtualize::visitCallSite (CallSite &CS) {
+ //
+ // First, determine if this is a direct call. If so, then just ignore it.
+ //
+ Value * CalledValue = CS.getCalledValue();
+ if (isa<Function>(CalledValue->stripPointerCasts()))
+ return;
+
+ //
+ // Second, we will only transform those call sites which are complete (i.e.,
+ // for which we know all of the call targets).
+ //
+ if (!(CTF->isComplete(CS)))
+ return;
+
+ //
+ // This is an indirect call site. Put it in the worklist of call sites to
+ // transform.
+ //
+ Worklist.push_back (CS.getInstruction());
+ return;
+}
- RegisterPass<Devirtualize> X("devirt", "Devirtualization");
+//
+// Method: runOnModule()
+//
+// Description:
+// Entry point for this LLVM transform pass. Look for indirect function calls
+// and turn them into direct function calls.
+//
+bool
+Devirtualize::runOnModule (Module & M) {
+ //
+ // Get the targets of indirect function calls.
+ //
+ CTF = &getAnalysis<CallTargetFinder>();
+ //
+ // Visit all of the call instructions in this module and record those that
+ // are indirect function calls.
+ //
+ visit (M);
+
+ //
+ // Now go through and transform all of the indirect calls that we found that
+ // need transforming.
+ //
+ for (unsigned index = 0; index < Worklist.size(); ++index) {
+ // Autobots, transform (the call site)!
+ CallSite CS (Worklist[index]);
+ makeDirectCall (CS);
+ }
+
+ //
+ // Conservatively assume that we've changed one or more call sites.
+ //
+ return true;
}
-#endif
More information about the llvm-commits
mailing list