[llvm-commits] [poolalloc] r110578 - in /poolalloc/trunk: include/assistDS/Devirt.h lib/AssistDS/Devirt.cpp

John Criswell criswell at uiuc.edu
Mon Aug 9 10:10:38 PDT 2010


Author: criswell
Date: Mon Aug  9 12:10:38 2010
New Revision: 110578

URL: http://llvm.org/viewvc/llvm-project?rev=110578&view=rev
Log:
Refactoring of Andrew Lenharth's devirtualization pass.
This version works on user-space code and currently devirtualizes all indirect
function calls for which all function targets are known.

Added:
    poolalloc/trunk/include/assistDS/Devirt.h
    poolalloc/trunk/lib/AssistDS/Devirt.cpp
      - copied, changed from r110449, poolalloc/trunk/lib/AssistDS/SVADevirt.cpp

Added: poolalloc/trunk/include/assistDS/Devirt.h
URL: http://llvm.org/viewvc/llvm-project/poolalloc/trunk/include/assistDS/Devirt.h?rev=110578&view=auto
==============================================================================
--- poolalloc/trunk/include/assistDS/Devirt.h (added)
+++ poolalloc/trunk/include/assistDS/Devirt.h Mon Aug  9 12:10:38 2010
@@ -0,0 +1,72 @@
+//===- Devirt.h - Devirtualize using the sig match intrinsic in llva ------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file was developed by the LLVM research group and is distributed under
+// the University of Illinois Open Source License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines an LLVM transform that converts indirect function calls
+// into direct function calls.
+//
+//===----------------------------------------------------------------------===//
+
+#include "dsa/CallTargets.h"
+
+#include "llvm/Constants.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Pass.h"
+#include "llvm/Module.h"
+#include "llvm/Function.h"
+#include "llvm/Instructions.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Support/InstVisitor.h"
+
+using namespace llvm;
+
+namespace llvm {
+  //
+  // Class: Devirtualize
+  //
+  // Description:
+  //  This transform pass will look for indirect function calls and transform
+  //  them into a switch statement that selects one of several direct function
+  //  calls to execute.
+  //
+  class Devirtualize : public ModulePass, public InstVisitor<Devirtualize> {
+    private:
+      // Access to analysis pass which finds targets of indirect function calls
+      CallTargetFinder* CTF;
+
+      // Worklist of call sites to transform
+      std::vector<Instruction *> Worklist;
+
+    protected:
+      void makeDirectCall (CallSite & CS);
+      Function* buildBounce (CallSite cs,std::vector<const Function*>& Targets);
+
+    public:
+      static char ID;
+      Devirtualize() : ModulePass(&ID), CTF(0) {}
+
+      virtual bool runOnModule(Module & M);
+
+      virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+        AU.addRequired<CallTargetFinder>();
+      }
+
+      // Visitor methods for analyzing instructions
+      //void visitInstruction(Instruction &I);
+      void visitCallSite(CallSite &CS);
+      void visitCallInst(CallInst &CI) {
+        CallSite CS(&CI);
+        visitCallSite(CS);
+      }
+      void visitInvokeInst(InvokeInst &II) {
+        CallSite CS(&II);
+        visitCallSite(CS);
+      }
+  };
+}
+

Copied: poolalloc/trunk/lib/AssistDS/Devirt.cpp (from r110449, poolalloc/trunk/lib/AssistDS/SVADevirt.cpp)
URL: http://llvm.org/viewvc/llvm-project/poolalloc/trunk/lib/AssistDS/Devirt.cpp?p2=poolalloc/trunk/lib/AssistDS/Devirt.cpp&p1=poolalloc/trunk/lib/AssistDS/SVADevirt.cpp&r1=110449&r2=110578&rev=110578&view=diff
==============================================================================
--- poolalloc/trunk/lib/AssistDS/SVADevirt.cpp (original)
+++ poolalloc/trunk/lib/AssistDS/Devirt.cpp Mon Aug  9 12:10:38 2010
@@ -6,15 +6,11 @@
 // the University of Illinois Open Source License. See LICENSE.TXT for details.
 //
 //===----------------------------------------------------------------------===//
+
 #define DEBUG_TYPE "devirt"
-#include "llvm/Constants.h"
-#include "llvm/Transforms/IPO.h"
-#include "dsa/CallTargets.h"
-#include "llvm/Pass.h"
-#include "llvm/Module.h"
-#include "llvm/Function.h"
-#include "llvm/Instructions.h"
-#include "llvm/DerivedTypes.h"
+
+#include "assistDS/Devirt.h"
+
 #include "llvm/Support/CommandLine.h"
 #include "llvm/ADT/Statistic.h"
 
@@ -24,15 +20,23 @@
 
 using namespace llvm;
 
-#if 0
+// Pass statistics
+STATISTIC(FuncAdded, "Number of bounce functions added");
+STATISTIC(CSConvert, "Number of call sites converted");
+
+// Pass registration
+RegisterPass<Devirtualize>
+X ("devirt", "Devirtualize indirect function calls");
 
 //
 // Function: castTo()
 //
-// Description: //  Given an LLVM value, insert a cast instruction to make it a given type.
+// Description:
+//  Given an LLVM value, insert a cast instruction to make it a given type.
 //
 static inline Value *
-castTo (Value * V, const Type * Ty, Instruction * InsertPt) {   //
+castTo (Value * V, const Type * Ty, std::string Name, Instruction * InsertPt) {
+  //
   // Don't bother creating a cast if it's already the correct type.
   //
   if (V->getType() == Ty)
@@ -42,215 +46,259 @@
   // If it's a constant, just create a constant expression.
   //
   if (Constant * C = dyn_cast<Constant>(V)) {
-    Constant * CE = ConstantExpr::getCast (C, Ty);
+    Constant * CE = ConstantExpr::getZExtOrBitCast (C, Ty);
     return CE;
   }
 
   //
   // Otherwise, insert a cast instruction.
   //
-  return new CastInst::Create(V, Ty, "cast", InsertPt);
+  return CastInst::CreateZExtOrBitCast (V, Ty, Name, InsertPt);
 }
 
-namespace {
+//
+// Method: buildBounce()
+//
+// Description:
+//  Builds a bounce function for the given call site and returns it.  The
+//  bounce function compares its function pointer argument against each of
+//  the given target functions and calls the matching target directly; if
+//  no target matches, control reaches an unreachable instruction.
+//
+Function*
+Devirtualize::buildBounce (CallSite CS, std::vector<const Function*>& Targets) {
+  //
+  // Update the statistics on the number of bounce functions added to the
+  // module.
+  //
+  ++FuncAdded;
 
-  static cl::opt<int>
-  VirtualLimit("devirt-limit", cl::Hidden, cl::init(16),
-               cl::desc("Maximum number of callees to devirtualize at a call site"));
-  STATISTIC(FuncAdded, "Number of bounce functions added");
-  STATISTIC(CSConvert, "Number of call sites converted");
-
-
-  class Devirtualize : public ModulePass {
-
-
-    Function * IndirectFuncFail;
-
-    std::map<std::pair<const Type*, std::vector<Function*> >, Function*> cache;
-    int fnum;
-
-    //
-    // Method: buildBounce()
-    //
-    // Description:
-    //  Replaces the given call site with a call to a bounce function.  The
-    //  bounce function compares the function pointer to one of the given
-    //  target functions and calls the function directly if the pointer
-    //  matches.
-    Function* buildBounce (CallSite cs,
-                           std::vector<Function*>& Targets,
-                           Module& M) {
-
-      Value* ptr = cs.getCalledValue();
-      const FunctionType* OrigType = 
-        cast<FunctionType>(cast<PointerType>(ptr->getType())->getElementType());;
-      ++FuncAdded;
-
-      std::vector< const Type *> TP(OrigType->param_begin(), OrigType->param_end());
-      TP.insert(TP.begin(), ptr->getType());
-      const FunctionType* NewTy = FunctionType::get(OrigType->getReturnType(), TP, false);
-      Function* F = new Function(NewTy, GlobalValue::InternalLinkage, "devirtbounce", &M);
-      std::map<Function*, BasicBlock*> targets;
-
-      F->arg_begin()->setName("funcPtr");
-      std::vector<Value*> fargs;
-      for(Function::arg_iterator ai = F->arg_begin(), ae = F->arg_end(); ai != ae; ++ai)
-        if (ai != F->arg_begin()) {
-          fargs.push_back(ai);
-          ai->setName("arg");
-        }
-
-      for (std::vector<Function*>::iterator i = Targets.begin(), e = Targets.end();
-           i != e; ++i) {
-        Function* FL = *i;
-        BasicBlock* BL = new BasicBlock(FL->getName(), F);
-        targets[FL] = BL;
-
-        //Make call
-        Value* call = new CallInst(FL, fargs, "", BL);
-
-        //return correctly
-        if (OrigType->getReturnType() == Type::VoidTy)
-          new ReturnInst(0, BL);
-        else
-          new ReturnInst(call, BL);
-      }
-
-      // Create a set of tests that search for the correct function target
-      // and call it directly.  If none of the target functions match,
-      // call pchk_ind_fail() to note the failure.
-
-      //
-      // Create the failure basic block.  Then, add the following:
-      //  o the terminating instruction
-      //  o the indirect call to the original function
-      //  o a call to phck_ind_fail()
-      //
-      BasicBlock* tail = new BasicBlock("fail", F, &F->getEntryBlock());
-      Instruction * InsertPt;
-#if 0
-      InsertPt = new UnreachableInst(tail);
-#else
-      Value* p = F->arg_begin();
-      Instruction * realCall = new CallInst (p, fargs, "", tail);
-      if (OrigType->getReturnType() == Type::VoidTy)
-        InsertPt = new ReturnInst(0, tail);
-      else
-        InsertPt = new ReturnInst(realCall, tail);
-#endif
-      Value * FuncVoidPtr = castTo (p,
-                                    PointerType::get(Type::SByteTy),
-                                    realCall);
-      new CallInst (IndirectFuncFail, FuncVoidPtr, "", realCall);
-      
-
-      // Create basic blocks for valid target functions
-      for (std::vector<Function*>::iterator i = Targets.begin(), e = Targets.end();
-           i != e; ++i) {
-        BasicBlock* TB = targets[*i];
-        BasicBlock* newB = new BasicBlock("test." + (*i)->getName(), F, &F->getEntryBlock());
-        SetCondInst* setcc = new SetCondInst(Instruction::SetEQ, *i, p, "sc", newB);
-        new BranchInst(TB, tail, setcc, newB);
-        tail = newB;
-      }
-      return F;
-    }
+  //
+  // Create a bounce function that has a function signature almost identical
+  // to the function being called.  The only difference is that it will have
+  // an additional pointer argument at the beginning of its argument list that
+  // will be the function to call.
+  //
+  Value* ptr = CS.getCalledValue();
+  const FunctionType* OrigType = 
+    cast<FunctionType>(cast<PointerType>(ptr->getType())->getElementType());
+
+  std::vector<const Type *> TP (OrigType->param_begin(), OrigType->param_end());
+  TP.insert (TP.begin(), ptr->getType());
+  const FunctionType* NewTy = FunctionType::get(OrigType->getReturnType(), TP, false);
+  Module * M = CS.getInstruction()->getParent()->getParent()->getParent();
+  Function* F = Function::Create (NewTy,
+                                  GlobalValue::InternalLinkage,
+                                  "devirtbounce",
+                                  M);
 
-  public:
-    static char ID;
-    Devirtualize() : ModulePass(&ID) {}
-
-    virtual bool runOnModule(Module &M) {
-      CallTargetFinder* CTF = &getAnalysis<CallTargetFinder>();
-
-      // Get references to functions that are needed in the module
-      Function* ams = M.getNamedFunction ("llva_assert_match_sig");
-      if (!ams)
-        return false;
-
-      IndirectFuncFail = M.getOrInsertFunction ("pchk_ind_fail",
-                                                Type::VoidTy,
-                                                PointerType::getUnqual(Type::Int8Ty),
-                                                NULL);
-      
-      std::set<Value*> safecalls;
-      std::vector<Instruction*> toDelete;
-
-      for (Value::use_iterator ii = ams->use_begin(), ee = ams->use_end();
-           ii != ee; ++ii) {
-        if (CallInst* CI = dyn_cast<CallInst>(*ii)) {
-          std::cerr << "Found safe call site in " 
-                    << CI->getParent()->getParent()->getName() << "\n";
-          Value* V = CI->getOperand(1);
-          toDelete.push_back(CI);
-          do {
-            //V->dump();
-            safecalls.insert(V);
-            if (CastInst* CV = dyn_cast<CastInst>(V))
-              V = CV->getOperand(0);
-            else V = 0;
-          } while (V);
-        }
-      }
-
-      for(std::set<Value*>::iterator i = safecalls.begin(), e = safecalls.end();
-          i != e; ++i) {
-        for (Value::use_iterator uii = (*i)->use_begin(), uie = (*i)->use_end();
-             uii != uie; ++uii) {
-          CallSite cs = CallSite::get(*uii);
-          bool isSafeCall = cs.getInstruction() && 
-            safecalls.find(cs.getCalledValue()) != safecalls.end();
-          if (cs.getInstruction() && !cs.getCalledFunction() &&
-              (isSafeCall || CTF->isComplete(cs))) {
-            std::vector<const Function*> Targets;
-            for (std::vector<const Function*>::iterator ii = CTF->begin(cs), ee = CTF->end(cs);
-                 ii != ee; ++ii)
-              if (!isSafeCall || (*ii)->getType() == cs.getCalledValue()->getType())
-                Targets.push_back(*ii);
-
-            if (Targets.size() > 0) {
-              std::cerr << "Target count: " << Targets.size() << " in " << cs.getInstruction()->getParent()->getParent()->getName() << "\n";
-              Function* NF = buildBounce(cs, Targets, M);
-              if (CallInst* ci = dyn_cast<CallInst>(cs.getInstruction())) {
-                ++CSConvert;
-                std::vector<Value*> Par(ci->op_begin(), ci->op_end());
-                CallInst* cn = CallInst::Create(NF, Par.begin(), Par.end(),
-                                                ci->getName() + ".dv", ci);
-                ci->replaceAllUsesWith(cn);
-                toDelete.push_back(ci);
-              } else if (InvokeInst* ci = dyn_cast<InvokeInst>(cs.getInstruction())) {
-                ++CSConvert;
-                std::vector<Value*> Par(ci->op_begin(), ci->op_end());
-                InvokeInst* cn = InvokeInst::Create(NF, ci->getNormalDest(),
-                                                    ci->getUnwindDest(),
-                                                    Par, ci->getName()+".dv",
-                                                    ci);
-                ci->replaceAllUsesWith(cn);
-                toDelete.push_back(ci);
-              }
-            } else //Target size == 0
-              std::cerr << "Call site found, but no Targets\n";
-          }
-        }
-      }
-
-      bool changed = false;
-      for (std::vector<Instruction*>::iterator ii = toDelete.begin(), ee = toDelete.end();
-           ii != ee; ++ii) {
-        changed = true;
-        (*ii)->eraseFromParent();
-      }
-      return changed;
+  //
+  // Set the names of the arguments.  Also, record the arguments in a vector
+  // for subsequent access.
+  //
+  F->arg_begin()->setName("funcPtr");
+  std::vector<Value*> fargs;
+  for(Function::arg_iterator ai = F->arg_begin(), ae = F->arg_end(); ai != ae; ++ai)
+    if (ai != F->arg_begin()) {
+      fargs.push_back(ai);
+      ai->setName("arg");
     }
 
-    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
-      AU.addRequired<CallTargetFinder>();
+  //
+  // For each function target, create a basic block that will call that
+  // function directly.
+  //
+  std::map<const Function*, BasicBlock*> targets;
+  for (unsigned index = 0; index < Targets.size(); ++index) {
+    const Function* FL = Targets[index];
+
+    // Create the basic block for doing the direct call
+    BasicBlock* BL = BasicBlock::Create (M->getContext(), FL->getName(), F);
+    targets[FL] = BL;
+
+    // Create the direct function call
+    Value* directCall = CallInst::Create ((Value *)FL,
+                                          fargs.begin(),
+                                          fargs.end(),
+                                          "",
+                                          BL);
+
+    // Add the return instruction for the basic block
+    if (OrigType->getReturnType()->isVoidTy())
+      ReturnInst::Create (M->getContext(), BL);
+    else
+      ReturnInst::Create (M->getContext(), directCall, BL);
+  }
+
+  //
+  // Create a set of tests that search for the correct function target
+  // and call it directly.  If none of the target functions match,
+  // abort (or make the result unreachable).
+  //
+
+  //
+  // Create the failure basic block.  This basic block should simply be an
+  // unreachable instruction.
+  //
+  BasicBlock* tail = BasicBlock::Create (M->getContext(),
+                                         "fail",
+                                         F,
+                                         &F->getEntryBlock());
+  Instruction * InsertPt;
+  InsertPt = new UnreachableInst (M->getContext(), tail);
+
+  //
+  // Create basic blocks for valid target functions.
+  //
+  for (unsigned index = 0; index < Targets.size(); ++index) {
+    const Function * Target = Targets[index];
+    BasicBlock* TB = targets[Target];
+    BasicBlock* newB = BasicBlock::Create (M->getContext(),
+                                           "test." + Target->getName(),
+                                           F,
+                                           &F->getEntryBlock());
+    CmpInst * setcc = CmpInst::Create (Instruction::ICmp,
+                                       CmpInst::ICMP_EQ,
+                                       (Value *) Target,
+                                       &(*(F->arg_begin())),
+                                       "sc",
+                                       newB);
+    BranchInst::Create (TB, tail, setcc, newB);
+    tail = newB;
+  }
+  return F;
+}
+
+//
+// Method: makeDirectCall()
+//
+// Description:
+//  Transform the specified call site into a direct call.
+//
+// Inputs:
+//  CS - The call site to transform.
+//
+// Preconditions:
+//  1) This method assumes that CS is an indirect call site.
+//  2) This method assumes that a pointer to the CallTarget analysis pass has
+//     already been acquired by the class.
+//
+void
+Devirtualize::makeDirectCall (CallSite & CS) {
+  //
+  // Find the targets of the indirect function call.
+  //
+  std::vector<const Function*> Targets;
+  Targets.insert (Targets.begin(), CTF->begin(CS), CTF->end(CS));
+
+  //
+  // Convert the call site if there were any function call targets found.
+  //
+  if (Targets.size() > 0) {
+    //
+    // Build a function which will implement a switch statement.  The switch
+    // statement will determine which function target to call and call it.
+    //
+    Function* NF = buildBounce (CS, Targets);
+
+    //
+    // Replace the original call with a call to the bounce function.
+    //
+    if (CallInst* CI = dyn_cast<CallInst>(CS.getInstruction())) {
+      std::vector<Value*> Params (CI->op_begin(), CI->op_end());
+      CallInst* CN = CallInst::Create (NF,
+                                       Params.begin(),
+                                       Params.end(),
+                                       CI->getName() + ".dv",
+                                       CI);
+      CI->replaceAllUsesWith(CN);
+      CI->eraseFromParent();
+    } else if (InvokeInst* CI = dyn_cast<InvokeInst>(CS.getInstruction())) {
+      std::vector<Value*> Params (CI->op_begin(), CI->op_end());
+      InvokeInst* CN = InvokeInst::Create(NF,
+                                          CI->getNormalDest(),
+                                          CI->getUnwindDest(),
+                                          Params.begin(),
+                                          Params.end(),
+                                          CI->getName()+".dv",
+                                          CI);
+      CI->replaceAllUsesWith(CN);
+      CI->eraseFromParent();
     }
 
-  };
+    //
+    // Update the statistics on the number of transformed call sites.
+    //
+    ++CSConvert;
+  }
+
+  return;
+}
+
+//
+// Method: visitCallSite()
+//
+// Description:
+//  Examine the specified call site.  If it is an indirect call, mark it for
+//  transformation into a direct call.
+//
+void
+Devirtualize::visitCallSite (CallSite &CS) {
+  //
+  // First, determine if this is a direct call.  If so, then just ignore it.
+  //
+  Value * CalledValue = CS.getCalledValue();
+  if (isa<Function>(CalledValue->stripPointerCasts()))
+    return;
+
+  //
+  // Second, we will only transform those call sites which are complete (i.e.,
+  // for which we know all of the call targets).
+  //
+  if (!(CTF->isComplete(CS)))
+    return;
+
+  //
+  // This is an indirect call site.  Put it in the worklist of call sites to
+  // transform.
+  //
+  Worklist.push_back (CS.getInstruction());
+  return;
+}
 
-  RegisterPass<Devirtualize> X("devirt", "Devirtualization");
+//
+// Method: runOnModule()
+//
+// Description:
+//  Entry point for this LLVM transform pass.  Look for indirect function calls
+//  and turn them into direct function calls.
+//
+bool
+Devirtualize::runOnModule (Module & M) {
+  //
+  // Get the targets of indirect function calls.
+  //
+  CTF = &getAnalysis<CallTargetFinder>();
 
+  //
+  // Visit all of the call instructions in this module and record those that
+  // are indirect function calls.
+  //
+  visit (M);
+
+  //
+  // Now go through and transform all of the indirect calls that we found that
+  // need transforming.
+  //
+  for (unsigned index = 0; index < Worklist.size(); ++index) {
+    // Autobots, transform (the call site)!
+    CallSite CS (Worklist[index]);
+    makeDirectCall (CS);
+  }
+
+  //
+  // Conservatively assume that we've changed one or more call sites.
+  //
+  return true;
 }
 
-#endif





More information about the llvm-commits mailing list