[llvm-commits] CVS: llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp

Chris Lattner sabre at nondot.org
Sat Apr 7 14:18:08 PDT 2007



Changes in directory llvm/lib/Transforms/IPO:

SimplifyLibCalls.cpp updated: 1.106 -> 1.107
---
Log message:

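Fix problems in the sprintf optimizer: require exactly two fixed i8* parameters and an integer return type, require a nul-terminated constant format string, honor the format string's start index when checking for '%', and use the call's own return type for the result of the no-format-specifier case.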


---
Diffs of the changes:  (+60 -81)

 SimplifyLibCalls.cpp |  141 +++++++++++++++++++++------------------------------
 1 files changed, 60 insertions(+), 81 deletions(-)


Index: llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp
diff -u llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp:1.106 llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp:1.107
--- llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp:1.106	Sat Apr  7 16:04:50 2007
+++ llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp	Sat Apr  7 16:17:51 2007
@@ -1276,8 +1276,7 @@
     if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4)
       return false;
 
-    // All the optimizations depend on the length of the second argument and the
-    // fact that it is a constant string array. Check that now
+    // All the optimizations depend on the format string.
     uint64_t FormatLen, FormatStartIdx;
     ConstantArray *CA = 0;
     if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx))
@@ -1368,108 +1367,88 @@
   SPrintFOptimization() : LibCallOptimization("sprintf",
       "Number of 'sprintf' calls simplified") {}
 
-  /// @brief Make sure that the "fprintf" function has the right prototype
-  virtual bool ValidateCalledFunction(const Function *f, SimplifyLibCalls &SLC){
-    // Just make sure this has at least 2 arguments
-    return (f->getReturnType() == Type::Int32Ty && f->arg_size() >= 2);
+  /// @brief Make sure that the "sprintf" function has the right prototype
+  virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){
+    const FunctionType *FT = F->getFunctionType();
+    return FT->getNumParams() == 2 &&  // two fixed arguments.
+           FT->getParamType(1) == PointerType::get(Type::Int8Ty) &&
+           FT->getParamType(0) == FT->getParamType(1) &&
+           isa<IntegerType>(FT->getReturnType());
   }
 
   /// @brief Perform the sprintf optimization.
-  virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &SLC) {
+  virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
     // If the call has more than 3 operands, we can't optimize it
-    if (ci->getNumOperands() > 4 || ci->getNumOperands() < 3)
+    if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4)
       return false;
 
-    // All the optimizations depend on the length of the second argument and the
-    // fact that it is a constant string array. Check that now
-    uint64_t len, StartIdx;
-    ConstantArray* CA = 0;
-    if (!GetConstantStringInfo(ci->getOperand(2), CA, len, StartIdx))
+    uint64_t FormatLen, FormatStartIdx;
+    ConstantArray *CA = 0;
+    if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx))
       return false;
-
-    if (ci->getNumOperands() == 3) {
-      if (len == 0) {
-        // If the length is 0, we just need to store a null byte
-        new StoreInst(ConstantInt::get(Type::Int8Ty,0),ci->getOperand(1),ci);
-        return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,0));
-      }
-
+    
+    if (CI->getNumOperands() == 3) {
+      if (!CA->isCString()) return false;
+      
       // Make sure there's no % in the constant array
-      for (unsigned i = 0; i < len; ++i) {
-        if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(i))) {
-          // Check for the null terminator
-          if (CI->getZExtValue() == '%')
-            return false; // we found a %, can't optimize
-        } else {
-          return false; // initializer is not constant int, can't optimize
-        }
-      }
-
-      // Increment length because we want to copy the null byte too
-      len++;
-
+      std::string S = CA->getAsString();
+      for (unsigned i = FormatStartIdx, e = S.size(); i != e; ++i)
+        if (S[i] == '%')
+          return false; // we found a format specifier
+      
       // sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1)
-      Value *args[4] = {
-        ci->getOperand(1),
-        ci->getOperand(2),
-        ConstantInt::get(SLC.getIntPtrType(),len),
+      Value *MemCpyArgs[] = {
+        CI->getOperand(1), CI->getOperand(2),
+        ConstantInt::get(SLC.getIntPtrType(), FormatLen+1), // Copy the nul byte
         ConstantInt::get(Type::Int32Ty, 1)
       };
-      new CallInst(SLC.get_memcpy(), args, 4, "", ci);
-      return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,len));
+      new CallInst(SLC.get_memcpy(), MemCpyArgs, 4, "", CI);
+      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), FormatLen));
     }
 
-    // The remaining optimizations require the format string to be length 2
-    // "%s" or "%c".
-    if (len != 2)
+    // The remaining optimizations require the format string to be "%s" or "%c".
+    if (FormatLen != 2 ||
+        cast<ConstantInt>(CA->getOperand(FormatStartIdx))->getZExtValue() !='%')
       return false;
 
-    // The first character has to be a %
-    if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(0)))
-      if (CI->getZExtValue() != '%')
-        return false;
-
     // Get the second character and switch on its value
-    ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(1));
-    switch (CI->getZExtValue()) {
+    switch (cast<ConstantInt>(CA->getOperand(1))->getZExtValue()) {
+    case 'c': {
+      // sprintf(dest,"%c",chr) -> store chr, dest
+      Value *V = CastInst::createTruncOrBitCast(CI->getOperand(3),
+                                                Type::Int8Ty, "char", CI);
+      new StoreInst(V, CI->getOperand(1), CI);
+      Value *Ptr = new GetElementPtrInst(CI->getOperand(1),
+                                         ConstantInt::get(Type::Int32Ty, 1),
+                                         CI->getOperand(1)->getName()+".end",
+                                         CI);
+      new StoreInst(ConstantInt::get(Type::Int8Ty,0), Ptr, CI);
+      return ReplaceCallWith(CI, ConstantInt::get(Type::Int32Ty, 1));
+    }
     case 's': {
       // sprintf(dest,"%s",str) -> llvm.memcpy(dest, str, strlen(str)+1, 1)
       Value *Len = new CallInst(SLC.get_strlen(),
-                                CastToCStr(ci->getOperand(3), ci),
-                                ci->getOperand(3)->getName()+".len", ci);
-      Value *Len1 = BinaryOperator::createAdd(Len,
-                                            ConstantInt::get(Len->getType(), 1),
-                                              Len->getName()+"1", ci);
-      if (Len1->getType() != SLC.getIntPtrType())
-        Len1 = CastInst::createIntegerCast(Len1, SLC.getIntPtrType(), false,
-                                           Len1->getName(), ci);
-      Value *args[4] = {
-        CastToCStr(ci->getOperand(1), ci),
-        CastToCStr(ci->getOperand(3), ci),
-        Len1,
-        ConstantInt::get(Type::Int32Ty,1)
+                                CastToCStr(CI->getOperand(3), CI),
+                                CI->getOperand(3)->getName()+".len", CI);
+      Value *UnincLen = Len;
+      Len = BinaryOperator::createAdd(Len, ConstantInt::get(Len->getType(), 1),
+                                      Len->getName()+"1", CI);
+      Value *MemcpyArgs[4] = {
+        CI->getOperand(1),
+        CastToCStr(CI->getOperand(3), CI),
+        Len,
+        ConstantInt::get(Type::Int32Ty, 1)
       };
-      new CallInst(SLC.get_memcpy(), args, 4, "", ci);
+      new CallInst(SLC.get_memcpy(), MemcpyArgs, 4, "", CI);
       
       // The strlen result is the unincremented number of bytes in the string.
-      if (!ci->use_empty()) {
-        if (Len->getType() != ci->getType())
-          Len = CastInst::createIntegerCast(Len, ci->getType(), false, 
-                                            Len->getName(), ci);
-        ci->replaceAllUsesWith(Len);
+      if (!CI->use_empty()) {
+        if (UnincLen->getType() != CI->getType())
+          UnincLen = CastInst::createIntegerCast(UnincLen, CI->getType(), false, 
+                                                 Len->getName(), CI);
+        CI->replaceAllUsesWith(UnincLen);
       }
-      return ReplaceCallWith(ci, 0);
-    }
-    case 'c': {
-      // sprintf(dest,"%c",chr) -> store chr, dest
-      CastInst* cast = CastInst::createTruncOrBitCast(
-          ci->getOperand(3), Type::Int8Ty, "char", ci);
-      new StoreInst(cast, ci->getOperand(1), ci);
-      GetElementPtrInst* gep = new GetElementPtrInst(ci->getOperand(1),
-        ConstantInt::get(Type::Int32Ty,1),ci->getOperand(1)->getName()+".end",
-        ci);
-      new StoreInst(ConstantInt::get(Type::Int8Ty,0),gep,ci);
-      return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty, 1));
+      return ReplaceCallWith(CI, 0);
     }
     }
     return false;






More information about the llvm-commits mailing list