[llvm-commits] [poolalloc] r133783 - in /poolalloc/trunk: include/assistDS/TypeChecks.h lib/AssistDS/TypeChecks.cpp

Arushi Aggarwal aggarwa4 at illinois.edu
Thu Jun 23 18:53:30 PDT 2011


Author: aggarwa4
Date: Thu Jun 23 20:53:30 2011
New Revision: 133783

URL: http://llvm.org/viewvc/llvm-project?rev=133783&view=rev
Log:
Some refactoring.
Visit basic blocks in dominator tree order when trying
to delete checks.

Modified:
    poolalloc/trunk/include/assistDS/TypeChecks.h
    poolalloc/trunk/lib/AssistDS/TypeChecks.cpp

Modified: poolalloc/trunk/include/assistDS/TypeChecks.h
URL: http://llvm.org/viewvc/llvm-project/poolalloc/trunk/include/assistDS/TypeChecks.h?rev=133783&r1=133782&r2=133783&view=diff
==============================================================================
--- poolalloc/trunk/include/assistDS/TypeChecks.h (original)
+++ poolalloc/trunk/include/assistDS/TypeChecks.h Thu Jun 23 20:53:30 2011
@@ -55,6 +55,7 @@
   bool initShadow(Module &M);
   void addTypeMap(Module &M) ;
   void optimizeChecks(Module &M);
+  void initRuntimeCheckPrototypes(Module &M);
   
   bool visitMain(Module &M, Function &F); 
 

Modified: poolalloc/trunk/lib/AssistDS/TypeChecks.cpp
URL: http://llvm.org/viewvc/llvm-project/poolalloc/trunk/lib/AssistDS/TypeChecks.cpp?rev=133783&r1=133782&r2=133783&view=diff
==============================================================================
--- poolalloc/trunk/lib/AssistDS/TypeChecks.cpp (original)
+++ poolalloc/trunk/lib/AssistDS/TypeChecks.cpp Thu Jun 23 20:53:30 2011
@@ -26,6 +26,7 @@
 
 #include <set>
 #include <vector>
+#include <deque>
 
 using namespace llvm;
 
@@ -66,6 +67,7 @@
 
 static Constant *One = 0;
 static Constant *Zero = 0;
+
 static Constant *RegisterArgv;
 static Constant *RegisterEnvp;
 
@@ -96,6 +98,7 @@
   if(UsedTypes.find(Ty) == UsedTypes.end())
     UsedTypes[Ty] = UsedTypes.size();
 
+  assert((UsedTypes.size() < 254) && "Too many types found. Not enough metadata bits");
   return UsedTypes[Ty];
 }
 
@@ -142,102 +145,8 @@
   One = ConstantInt::get(Int64Ty, 1);
   Zero = ConstantInt::get(Int64Ty, 0);
 
-  RegisterArgv = M.getOrInsertFunction("trackArgvType",
-                                       VoidTy,
-                                       Int32Ty, /*argc */
-                                       VoidPtrTy->getPointerTo(),/*argv*/
-                                       NULL);
-  RegisterEnvp = M.getOrInsertFunction("trackEnvpType",
-                                       VoidTy,
-                                       VoidPtrTy->getPointerTo(),/*envp*/
-                                       NULL);
-  trackGlobal = M.getOrInsertFunction("trackGlobal",
-                                      VoidTy,
-                                      VoidPtrTy,/*ptr*/
-                                      TypeTagTy,/*type*/
-                                      Int64Ty,/*size*/
-                                      Int32Ty,/*tag*/
-                                      NULL);
-  trackArray = M.getOrInsertFunction("trackArray",
-                                     VoidTy,
-                                     VoidPtrTy,/*ptr*/
-                                     Int64Ty,/*size*/
-                                     Int64Ty,/*count*/
-                                     Int32Ty,/*tag*/
-                                     NULL);
-  trackInitInst = M.getOrInsertFunction("trackInitInst",
-                                        VoidTy,
-                                        VoidPtrTy,/*ptr*/
-                                        Int64Ty,/*size*/
-                                        Int32Ty,/*tag*/
-                                        NULL);
-  trackUnInitInst = M.getOrInsertFunction("trackUnInitInst",
-                                          VoidTy,
-                                          VoidPtrTy,/*ptr*/
-                                          Int64Ty,/*size*/
-                                          Int32Ty,/*tag*/
-                                          NULL);
-  trackStoreInst = M.getOrInsertFunction("trackStoreInst",
-                                         VoidTy,
-                                         VoidPtrTy,/*ptr*/
-                                         TypeTagTy,/*type*/
-                                         Int64Ty,/*size*/
-                                         Int32Ty,/*tag*/
-                                         NULL);
-  getTypeTag = M.getOrInsertFunction("getTypeTag",
-                                     VoidTy,
-                                     VoidPtrTy, /*ptr*/
-                                     Int64Ty, /*size*/
-                                     TypeTagPtrTy, /*dest for type tag*/
-                                     Int32Ty, /*tag*/
-                                     NULL);
-  checkTypeInst = M.getOrInsertFunction("checkType",
-                                        VoidTy,
-                                        TypeTagTy,/*type*/
-                                        Int64Ty,/*size*/
-                                        TypeTagPtrTy,/*ptr to metadata*/
-                                        VoidPtrTy,/*ptr*/
-                                        Int32Ty,/*tag*/
-                                        NULL);
-  setTypeInfo = M.getOrInsertFunction("setTypeInfo",
-                                       VoidTy,
-                                       VoidPtrTy,/*dest ptr*/
-                                       TypeTagPtrTy,/*metadata*/
-                                       Int64Ty,/*size*/
-                                       Int32Ty,/*tag*/
-                                       NULL);
-  copyTypeInfo = M.getOrInsertFunction("copyTypeInfo",
-                                       VoidTy,
-                                       VoidPtrTy,/*dest ptr*/
-                                       VoidPtrTy,/*src ptr*/
-                                       Int64Ty,/*size*/
-                                       Int32Ty,/*tag*/
-                                       NULL);
-  trackStringInput = M.getOrInsertFunction("trackStringInput",
-                                           VoidTy,
-                                           VoidPtrTy,
-                                           Int32Ty,
-                                           NULL);
-  setVAInfo = M.getOrInsertFunction("setVAInfo",
-                                    VoidTy,
-                                    VoidPtrTy,/*va_list ptr*/
-                                    Int64Ty,/*total num of elements in va_list */
-                                    TypeTagPtrTy,/*ptr to metadta*/
-                                    Int32Ty,/*tag*/
-                                    NULL);
-  copyVAInfo = M.getOrInsertFunction("copyVAInfo",
-                                     VoidTy,
-                                     VoidPtrTy,/*dst va_list*/
-                                     VoidPtrTy,/*src va_list */
-                                     Int32Ty,/*tag*/
-                                     NULL);
-  checkVAArg = M.getOrInsertFunction("checkVAArgType",
-                                     VoidTy,
-                                     VoidPtrTy,/*va_list ptr*/
-                                     TypeTagTy,/*type*/
-                                     Int32Ty,/*tag*/
-                                     NULL);
-
+  // Add prototypes for the dynamic type checking functions
+  initRuntimeCheckPrototypes(M);
 
   UsedTypes.clear(); // Reset if run multiple times.
   VAArgFunctions.clear();
@@ -254,7 +163,7 @@
   // Insert the shadow initialization function.
   modified |= initShadow(M);
 
-  // Record argv
+  // Record argv/envp
   modified |= visitMain(M, *MainF);
 
   // Recognize special cases
@@ -298,6 +207,7 @@
     modified |= visitByValFunction(M, *F);
   }
 
+  // Modify all the var arg functions
   while(!VAArgFunctions.empty()) {
     Function *F = VAArgFunctions.back();
     VAArgFunctions.pop_back();
@@ -305,6 +215,7 @@
     modified |= visitVarArgFunction(M, *F);
   }
 
+  // Modify all the address taken functions
   while(!AddressTakenFunctions.empty()) {
     Function *F = AddressTakenFunctions.back();
     AddressTakenFunctions.pop_back();
@@ -318,8 +229,8 @@
     if(F.isDeclaration())
       continue;
 
-    // Loop over all of the instructions in the function, 
-    // adding their return type as well as the types of their operands.
+    // Loop over all of the instructions in the function,
+    // adding instrumentation where needed.
     for (inst_iterator II = inst_begin(F), IE = inst_end(F); II != IE;++II) {
       Instruction &I = *II;
       if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
@@ -339,7 +250,6 @@
     }
   }
 
-  // visit all the uses of the address taken functions and modify if
   // visit all the indirect call sites
   std::set<Instruction*>::iterator II = IndCalls.begin();
   for(; II != IndCalls.end();) {
@@ -347,6 +257,8 @@
     modified |= visitIndirectCallSite(M,I);
   }
 
+  // visit all the uses of the address taken functions and modify them
+  // if they are not passed to external code
   std::map<Function *, Function * >::iterator FI = IndFunctionsMap.begin(), FE = IndFunctionsMap.end();
   for(;FI!=FE;++FI) {
     Function *F = FI->first;
@@ -404,6 +316,8 @@
     }
   }
 
+  // Remove redundant checks, caused by instrumenting every use of a load:
+  // a check is removed if it is dominated by an identical check on the same pointer
   optimizeChecks(M);
 
   // add a global that contains the mapping from metadata to strings
@@ -415,7 +329,118 @@
   return modified;
 }
 
+void TypeChecks::initRuntimeCheckPrototypes(Module &M) {
+  // Get or insert, in M, prototypes for the dynamic type checking runtime functions.
+  RegisterArgv = M.getOrInsertFunction("trackArgvType",
+                                       VoidTy,
+                                       Int32Ty, /*argc */
+                                       VoidPtrTy->getPointerTo(),/*argv*/
+                                       NULL);
+
+  RegisterEnvp = M.getOrInsertFunction("trackEnvpType",
+                                       VoidTy,
+                                       VoidPtrTy->getPointerTo(),/*envp*/
+                                       NULL);
+  
+  trackGlobal = M.getOrInsertFunction("trackGlobal",
+                                      VoidTy,
+                                      VoidPtrTy,/*ptr*/
+                                      TypeTagTy,/*type*/
+                                      Int64Ty,/*size*/
+                                      Int32Ty,/*tag*/
+                                      NULL);
+
+  trackArray = M.getOrInsertFunction("trackArray",
+                                     VoidTy,
+                                     VoidPtrTy,/*ptr*/
+                                     Int64Ty,/*size*/
+                                     Int64Ty,/*count*/
+                                     Int32Ty,/*tag*/
+                                     NULL);
+  
+  trackInitInst = M.getOrInsertFunction("trackInitInst",
+                                        VoidTy,
+                                        VoidPtrTy,/*ptr*/
+                                        Int64Ty,/*size*/
+                                        Int32Ty,/*tag*/
+                                        NULL);
+
+  trackUnInitInst = M.getOrInsertFunction("trackUnInitInst",
+                                          VoidTy,
+                                          VoidPtrTy,/*ptr*/
+                                          Int64Ty,/*size*/
+                                          Int32Ty,/*tag*/
+                                          NULL);
+
+  trackStoreInst = M.getOrInsertFunction("trackStoreInst",
+                                         VoidTy,
+                                         VoidPtrTy,/*ptr*/
+                                         TypeTagTy,/*type*/
+                                         Int64Ty,/*size*/
+                                         Int32Ty,/*tag*/
+                                         NULL);
+  getTypeTag = M.getOrInsertFunction("getTypeTag",
+                                     VoidTy,
+                                     VoidPtrTy, /*ptr*/
+                                     Int64Ty, /*size*/
+                                     TypeTagPtrTy, /*dest for type tag*/
+                                     Int32Ty, /*tag*/
+                                     NULL);
+  checkTypeInst = M.getOrInsertFunction("checkType",
+                                        VoidTy,
+                                        TypeTagTy,/*type*/
+                                        Int64Ty,/*size*/
+                                        TypeTagPtrTy,/*ptr to metadata*/
+                                        VoidPtrTy,/*ptr*/
+                                        Int32Ty,/*tag*/
+                                        NULL);
+  setTypeInfo = M.getOrInsertFunction("setTypeInfo",
+                                       VoidTy,
+                                       VoidPtrTy,/*dest ptr*/
+                                       TypeTagPtrTy,/*metadata*/
+                                       Int64Ty,/*size*/
+                                       Int32Ty,/*tag*/
+                                       NULL);
+  copyTypeInfo = M.getOrInsertFunction("copyTypeInfo",
+                                       VoidTy,
+                                       VoidPtrTy,/*dest ptr*/
+                                       VoidPtrTy,/*src ptr*/
+                                       Int64Ty,/*size*/
+                                       Int32Ty,/*tag*/
+                                       NULL);
+  trackStringInput = M.getOrInsertFunction("trackStringInput",
+                                           VoidTy,
+                                           VoidPtrTy,
+                                           Int32Ty,
+                                           NULL);
+  setVAInfo = M.getOrInsertFunction("setVAInfo",
+                                    VoidTy,
+                                    VoidPtrTy,/*va_list ptr*/
+                                    Int64Ty,/*total num of elements in va_list */
+                                    TypeTagPtrTy,/*ptr to metadata*/
+                                    Int32Ty,/*tag*/
+                                    NULL);
+  copyVAInfo = M.getOrInsertFunction("copyVAInfo",
+                                     VoidTy,
+                                     VoidPtrTy,/*dst va_list*/
+                                     VoidPtrTy,/*src va_list */
+                                     Int32Ty,/*tag*/
+                                     NULL);
+  checkVAArg = M.getOrInsertFunction("checkVAArgType",
+                                     VoidTy,
+                                     VoidPtrTy,/*va_list ptr*/
+                                     TypeTagTy,/*type*/
+                                     Int32Ty,/*tag*/
+                                     NULL);
+
+}
+
+// Delete a check if it is dominated by another check on the same value.
+// A path may contain multiple checks when a load instruction has
+// multiple uses.
+/*
 void TypeChecks::optimizeChecks(Module &M) {
+  // TODO: visit in dominator tree order
   for (Module::iterator MI = M.begin(), ME = M.end(); MI != ME; ++MI) {
     Function &F = *MI;
     if(F.isDeclaration())
@@ -435,13 +460,15 @@
             continue;
           if(CI2->getParent()->getParent() != &F)
             continue;
+          // Check that they are refering to the same pointer
           if(CI->getOperand(4) != CI2->getOperand(4))
             continue;
+          // Check that they are using the same metadata for comparison.
           if(CI->getOperand(3) != CI2->getOperand(3))
             continue;
+          // if CI, dominates CI2, delete CI2
           if(!DT.dominates(CI, CI2))
             continue;
-          CI->dump();
           CI2->dump();
           toDelete.push_back(CI2);
         }
@@ -453,7 +480,53 @@
       }
     }
   }
+}*/
 
+void TypeChecks::optimizeChecks(Module &M) {
+  for (Module::iterator MI = M.begin(), ME = M.end(); MI != ME; ++MI) {
+    Function &F = *MI;
+    if(F.isDeclaration())
+      continue;
+    DominatorTree & DT = getAnalysis<DominatorTree>(F);
+    // Walk the dominator tree breadth-first from the root, so every
+    // check is visited before the checks that it dominates.
+    std::deque<DomTreeNode *> Worklist;
+    Worklist.push_back (DT.getRootNode());
+    while(!Worklist.empty()) {
+      DomTreeNode * Node = Worklist.front();
+      Worklist.pop_front();
+      BasicBlock *BB = Node->getBlock();
+      for (BasicBlock::iterator bi = BB->begin(); bi != BB->end(); ++bi) {
+        CallInst *CI = dyn_cast<CallInst>(bi);
+        if(!CI || CI->getCalledFunction() != checkTypeInst)
+          continue;
+        std::vector<Instruction *> toDelete;
+        for(Value::use_iterator User = checkTypeInst->use_begin(); User != checkTypeInst->use_end(); ++User) {
+          CallInst *CI2 = dyn_cast<CallInst>(User);
+          // Skip uses that are not calls (CI2 is null) and CI itself.
+          if(!CI2 || CI2 == CI)
+            continue;
+          if(CI2->getParent()->getParent() != &F)
+            continue;
+          // Check that they are referring to the same pointer
+          if(CI->getOperand(4) != CI2->getOperand(4))
+            continue;
+          // Check that they are using the same metadata for comparison.
+          if(CI->getOperand(3) != CI2->getOperand(3))
+            continue;
+          // If CI dominates CI2, CI2 is redundant.
+          if(!DT.dominates(CI, CI2))
+            continue;
+          toDelete.push_back(CI2);
+        }
+        // Erase only after the use-list walk, so the use iterator is never
+        // invalidated. CI itself is never erased here, so bi stays valid.
+        for(unsigned i = 0, e = toDelete.size(); i != e; ++i)
+          toDelete[i]->eraseFromParent();
+      }
+      Worklist.insert(Worklist.end(), Node->begin(), Node->end());
+    }
+  }
 }
 
 // add a global that has the metadata -> typeString mapping
@@ -516,6 +589,9 @@
                     );
 }
 
+// Create a clone of the address taken function F
+// that takes 2 extra arguments (like a var arg function),
+// and modify its call sites.
 bool TypeChecks::visitAddressTakenFunction(Module &M, Function &F) {
   // Clone function
   // 1. Create the new argument types vector
@@ -560,11 +636,16 @@
   // 5. Perform the cloning
   SmallVector<ReturnInst*, 100>Returns;
   CloneFunctionInto(NewF, &F, ValueMap, Returns);
+  // Store in the map of original -> cloned function
   IndFunctionsMap[&F] = NewF;
 
   // Find all uses of the function
   for(Value::use_iterator ui = F.use_begin(), ue = F.use_end();
       ui != ue;)  {
+    if(isa<InvokeInst>(ui)) {
+      ui->dump();
+      assert(0 && "Handle invoke inst here");
+    }
     // Check for call sites
     CallInst *CI = dyn_cast<CallInst>(ui++);
     if(!CI)
@@ -1067,6 +1148,7 @@
   for (; I != E; ++I) {
     OS << "  ";
     WriteTypeSymbolic(OS, I->first, M);
+    OS << " : " << I->second;
     OS << '\n';
   }
 
@@ -1681,6 +1763,7 @@
   return false;
 }
 
+// Add extra arguments to each indirect call site
 bool TypeChecks::visitIndirectCallSite(Module &M, Instruction *I) {
   // add the number of arguments as the first argument
   const Type* OrigType = I->getOperand(0)->getType();
@@ -1759,10 +1842,7 @@
     I->eraseFromParent();
 
   }
-
-
-  // add they types of the argument as the second argument
-  return false;
+  return true;
 }
 
 bool TypeChecks::visitInputFunctionValue(Module &M, Value *V, Instruction *CI) {
@@ -1829,8 +1909,10 @@
   numLoadChecks++;
   return true;
 }
+
 // AI - metadata
 // BCI - ptr
+// I - instruction whose uses are instrumented
 bool TypeChecks::visitUses(Instruction *I, AllocaInst *AI, CastInst *BCI) {
   for(Value::use_iterator II = I->use_begin(); II != I->use_end(); ++II) {
     if(DisablePtrCmpChecks) {
@@ -1869,6 +1951,7 @@
   }
   return true;
 }
+
 // Insert runtime checks before all store instructions.
 bool TypeChecks::visitStoreInst(Module &M, StoreInst &SI) {
   // Cast the pointer operand to i8* for the runtime function.





More information about the llvm-commits mailing list