[llvm-commits] [llvm] r157649 - /llvm/trunk/lib/Transforms/Scalar/BoundsChecking.cpp

Nuno Lopes nunoplopes at sapo.pt
Tue May 29 15:32:52 PDT 2012


Author: nlopes
Date: Tue May 29 17:32:51 2012
New Revision: 157649

URL: http://llvm.org/viewvc/llvm-project?rev=157649&view=rev
Log:
bounds checking:
 - hoist checks out of loops where SCEV is smart enough
 - add additional statistics to measure how much we lose by not supporting interprocedural analysis and pointers loaded from memory

Modified:
    llvm/trunk/lib/Transforms/Scalar/BoundsChecking.cpp

Modified: llvm/trunk/lib/Transforms/Scalar/BoundsChecking.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/BoundsChecking.cpp?rev=157649&r1=157648&r2=157649&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/BoundsChecking.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/BoundsChecking.cpp Tue May 29 17:32:51 2012
@@ -15,7 +15,10 @@
 #define DEBUG_TYPE "bounds-checking"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/ADT/Statistic.h"
-#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/InstIterator.h"
 #include "llvm/Support/IRBuilder.h"
@@ -34,6 +37,8 @@
 STATISTIC(ChecksAdded, "Bounds checks added");
 STATISTIC(ChecksSkipped, "Bounds checks skipped");
 STATISTIC(ChecksUnable, "Bounds checks unable to add");
+STATISTIC(ChecksUnableInterproc, "Bounds checks unable to add (interprocedural)");
+STATISTIC(ChecksUnableLoad, "Bounds checks unable to add (LoadInst)");
 
 typedef IRBuilder<true, TargetFolder> BuilderTy;
 
@@ -53,10 +58,14 @@
 
     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
       AU.addRequired<TargetData>();
+      AU.addRequired<LoopInfo>();
+      AU.addRequired<ScalarEvolution>();
     }
 
   private:
     const TargetData *TD;
+    LoopInfo *LI;
+    ScalarEvolution *SE;
     BuilderTy *Builder;
     Function *Fn;
     BasicBlock *TrapBB;
@@ -71,8 +80,11 @@
 }
 
 char BoundsChecking::ID = 0;
-INITIALIZE_PASS(BoundsChecking, "bounds-checking", "Run-time bounds checking",
-                false, false)
+INITIALIZE_PASS_BEGIN(BoundsChecking, "bounds-checking",
+                      "Run-time bounds checking", false, false)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
+INITIALIZE_PASS_END(BoundsChecking, "bounds-checking",
+                      "Run-time bounds checking", false, false)
 
 
 /// getTrapBB - create a basic block that traps. All overflowing conditions
@@ -153,8 +165,10 @@
 
   // function arguments
   } else if (Argument *A = dyn_cast<Argument>(Alloc)) {
-    if (!A->hasByValAttr())
+    if (!A->hasByValAttr()) {
+      ++ChecksUnableInterproc;
       return Dunno;
+    }
 
     PointerType *PT = cast<PointerType>(A->getType());
     Size = TD->getTypeAllocSize(PT->getElementType());
@@ -268,7 +282,6 @@
       }
       SizeValue = Builder->CreateMul(SizeValue, Arg);
     }
-
     return NotConst;
 
     // TODO: handle more standard functions:
@@ -276,9 +289,12 @@
     // - strcpy / strncpy
     // - memcpy / memmove
     // - strcat / strncat
+
+  } else if (isa<LoadInst>(Alloc)) {
+    ++ChecksUnableLoad;
+    return Dunno;
   }
 
-  DEBUG(dbgs() << "computeAllocSize failed:\n" << *Alloc << "\n");
   return Dunno;
 }
 
@@ -293,13 +309,30 @@
   DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize)
               << " bytes\n");
 
-  Type *SizeTy = Type::getInt64Ty(Fn->getContext());
+  Type *SizeTy = TD->getIntPtrType(Fn->getContext());
 
   // Get to the real allocated thing and offset as fast as possible.
   Ptr = Ptr->stripPointerCasts();
-  GEPOperator *GEP;
 
-  if ((GEP = dyn_cast<GEPOperator>(Ptr))) {
+  // try to hoist the check to the end of the loop preheader (if in a loop)
+  Value *LoopOffset = 0;
+  Loop *L = LI->getLoopFor(Builder->GetInsertPoint()->getParent());
+  if (L && L->getLoopPreheader()) { // getLoopPreheader() may return null
+    const SCEV *PtrSCEV  = SE->getSCEVAtScope(Ptr, L->getParentLoop());
+    const SCEV *BaseSCEV = SE->getPointerBase(PtrSCEV);
+    const SCEVUnknown *PointerBase = dyn_cast<SCEVUnknown>(BaseSCEV);
+    if (PointerBase && SE->isLoopInvariant(PtrSCEV, L)) { // no AddRec on L
+      Ptr = PointerBase->getValue()->stripPointerCasts();
+      Instruction *InsertPoint = L->getLoopPreheader()->getTerminator();
+      Builder->SetInsertPoint(InsertPoint);
+      SCEVExpander Expander(*SE, "bounds-checking");
+      const SCEV *OffsetSCEV = SE->getMinusSCEV(PtrSCEV, PointerBase);
+      LoopOffset = Expander.expandCodeFor(OffsetSCEV, SizeTy, InsertPoint);
+    }
+  }
+
+  GEPOperator *GEP = dyn_cast<GEPOperator>(Ptr);
+  if (GEP) {
     // check if we will be able to get the offset
     if (!GEP->hasAllConstantIndices() && Penalty < 2) {
       ++ChecksUnable;
@@ -312,6 +345,7 @@
   Value *SizeValue = 0;
   ConstTriState ConstAlloc = computeAllocSize(Ptr, Size, SizeValue);
   if (ConstAlloc == Dunno) {
+    DEBUG(dbgs() << "computeAllocSize failed:\n" << *Ptr << "\n");
     ++ChecksUnable;
     return false;
   }
@@ -330,7 +364,7 @@
     }
   }
 
-  if (!OffsetValue && ConstAlloc == Const) {
+  if (!LoopOffset && !OffsetValue && ConstAlloc == Const) {
     if (Size < Offset || (Size - Offset) < NeededSize) {
       // Out of bounds
       emitBranchToTrap();
@@ -342,9 +376,7 @@
     return false;
   }
 
-  if (OffsetValue)
-    OffsetValue = Builder->CreateZExt(OffsetValue, SizeTy);
-  else
+  if (!OffsetValue)
     OffsetValue = ConstantInt::get(SizeTy, Offset);
 
   if (SizeValue)
@@ -352,6 +384,10 @@
   else
     SizeValue = ConstantInt::get(SizeTy, Size);
 
+  // add the loop offset if the check was hoisted
+  if (LoopOffset)
+    OffsetValue = Builder->CreateAdd(OffsetValue, LoopOffset);
+
   Value *NeededSizeVal = ConstantInt::get(SizeTy, NeededSize);
   Value *ObjSize = Builder->CreateSub(SizeValue, OffsetValue);
   Value *Cmp1 = Builder->CreateICmpULT(SizeValue, OffsetValue);
@@ -365,6 +401,8 @@
 
 bool BoundsChecking::runOnFunction(Function &F) {
   TD = &getAnalysis<TargetData>();
+  LI = &getAnalysis<LoopInfo>();
+  SE = &getAnalysis<ScalarEvolution>();
 
   TrapBB = 0;
   Fn = &F;





More information about the llvm-commits mailing list