[llvm-commits] CVS: llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp

Chris Lattner sabre at nondot.org
Sat Apr 7 14:05:08 PDT 2007



Changes in directory llvm/lib/Transforms/IPO:

SimplifyLibCalls.cpp updated: 1.105 -> 1.106
---
Log message:

Change CastToCStr to take a pointer instead of a reference.
Fix some miscompilations in fprintf optimizer.


---
Diffs of the changes:  (+76 -82)

 SimplifyLibCalls.cpp |  158 ++++++++++++++++++++++++---------------------------
 1 files changed, 76 insertions(+), 82 deletions(-)


Index: llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp
diff -u llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp:1.105 llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp:1.106
--- llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp:1.105	Sat Apr  7 15:19:08 2007
+++ llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp	Sat Apr  7 16:04:50 2007
@@ -393,7 +393,7 @@
 // Forward declare utility functions.
 static bool GetConstantStringInfo(Value *V, ConstantArray *&Array,
                                   uint64_t &Length, uint64_t &StartIdx);
-static Value *CastToCStr(Value *V, Instruction &IP);
+static Value *CastToCStr(Value *V, Instruction *IP);
 
 /// This LibCallOptimization will find instances of a call to "exit" that occurs
 /// within the "main" function and change it to a simple "ret" instruction with
@@ -1228,7 +1228,7 @@
         return false;
 
       // printf("%s\n",str) -> puts(str)
-      new CallInst(SLC.get_puts(), CastToCStr(CI->getOperand(2), *CI),
+      new CallInst(SLC.get_puts(), CastToCStr(CI->getOperand(2), CI),
                    CI->getName(), CI);
       return ReplaceCallWith(CI, 0);
     }
@@ -1262,104 +1262,98 @@
       "Number of 'fprintf' calls simplified") {}
 
   /// @brief Make sure that the "fprintf" function has the right prototype
-  virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC){
-    // Just make sure this has at least 2 arguments
-    return (f->arg_size() >= 2);
+  virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){
+    const FunctionType *FT = F->getFunctionType();
+    return FT->getNumParams() == 2 &&  // fixed params: stream ptr, i8* fmt.
+           FT->getParamType(1) == PointerType::get(Type::Int8Ty) &&
+           isa<PointerType>(FT->getParamType(0)) &&  // stream
+           isa<IntegerType>(FT->getReturnType());    // integer of any width
   }
 
   /// @brief Perform the fprintf optimization.
-  virtual bool OptimizeCall(CallInst* ci, SimplifyLibCalls& SLC) {
+  virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
     // If the call has more than 3 operands, we can't optimize it
-    if (ci->getNumOperands() > 4 || ci->getNumOperands() <= 2)
-      return false;
-
-    // If the result of the fprintf call is used, none of these optimizations
-    // can be made.
-    if (!ci->use_empty())
+    if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4)
       return false;

     // All the optimizations depend on the length of the second argument and the
     // fact that it is a constant string array. Check that now
-    uint64_t len, StartIdx;
-    ConstantArray* CA = 0;
-    if (!GetConstantStringInfo(ci->getOperand(2), CA, len, StartIdx))
+    uint64_t FormatLen, FormatStartIdx;
+    ConstantArray *CA = 0;
+    if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx))
       return false;

-    if (ci->getNumOperands() == 3) {
+    // If this is just a format string, turn it into fwrite.
+    if (CI->getNumOperands() == 3) {
+      if (!CA->isCString()) return false;
+      
       // Make sure there's no % in the constant array
-      for (unsigned i = 0; i < len; ++i) {
-        if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(i))) {
-          // Check for the null terminator
-          if (CI->getZExtValue() == '%')
-            return false; // we found end of string
-        } else {
-          return false;
-        }
-      }
+      std::string S = CA->getAsString();

-      // fprintf(file,fmt) -> fwrite(fmt,strlen(fmt),file)
-      const Type* FILEptr_type = ci->getOperand(1)->getType();
+      for (unsigned i = FormatStartIdx, e = S.size(); i != e; ++i)
+        if (S[i] == '%')
+          return false; // we found a format specifier

-      // Make sure that the fprintf() and fwrite() functions both take the
-      // same type of char pointer.
-      if (ci->getOperand(2)->getType() != PointerType::get(Type::Int8Ty))
-        return false;
+      // fprintf(file,fmt) -> fwrite(fmt,strlen(fmt),file)
+      const Type *FILEty = CI->getOperand(1)->getType();

-      Value* args[4] = {
-        ci->getOperand(2),
-        ConstantInt::get(SLC.getIntPtrType(),len),
-        ConstantInt::get(SLC.getIntPtrType(),1),
-        ci->getOperand(1)
+      Value *FWriteArgs[] = {
+        CI->getOperand(2),
+        ConstantInt::get(SLC.getIntPtrType(), FormatLen),
+        ConstantInt::get(SLC.getIntPtrType(), 1),
+        CI->getOperand(1)
       };
-      new CallInst(SLC.get_fwrite(FILEptr_type), args, 4, ci->getName(), ci);
-      return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,len));
+      new CallInst(SLC.get_fwrite(FILEty), FWriteArgs, 4, CI->getName(), CI);
+      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), FormatLen));
     }
-
-    // The remaining optimizations require the format string to be length 2
+    
+    // The remaining optimizations require the format string to be length 2:
     // "%s" or "%c".
-    if (len != 2)
+    if (FormatLen != 2)
       return false;

-    // The first character has to be a %
-    if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(0)))
-      if (CI->getZExtValue() != '%')
-        return false;
+    // The first character has to be a % for us to handle it.
+    if (cast<ConstantInt>(CA->getOperand(FormatStartIdx))->getZExtValue() !='%')
+      return false;

     // Get the second character and switch on its value
-    ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(1));
-    switch (CI->getZExtValue()) {
-      case 's': {
-        uint64_t len, StartIdx;
-        ConstantArray* CA = 0;
-        if (GetConstantStringInfo(ci->getOperand(3), CA, len, StartIdx)) {
-          // fprintf(file,"%s",str) -> fwrite(str,strlen(str),1,file)
-          const Type* FILEptr_type = ci->getOperand(1)->getType();
-          Value* args[4] = {
-            CastToCStr(ci->getOperand(3), *ci),
-            ConstantInt::get(SLC.getIntPtrType(), len),
-            ConstantInt::get(SLC.getIntPtrType(), 1),
-            ci->getOperand(1)
-          };
-          new CallInst(SLC.get_fwrite(FILEptr_type), args, 4,ci->getName(), ci);
-          return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty, len));
-        }
-        // fprintf(file,"%s",str) -> fputs(str,file)
-        const Type* FILEptr_type = ci->getOperand(1)->getType();
-        new CallInst(SLC.get_fputs(FILEptr_type),
-                     CastToCStr(ci->getOperand(3), *ci),
-                     ci->getOperand(1), ci->getName(),ci);
-        return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,len));
-      }
-      case 'c': {
-        // fprintf(file,"%c",c) -> fputc(c,file)
-        const Type* FILEptr_type = ci->getOperand(1)->getType();
-        CastInst* cast = CastInst::createSExtOrBitCast(
-            ci->getOperand(3), Type::Int32Ty, CI->getName()+".int", ci);
-        new CallInst(SLC.get_fputc(FILEptr_type), cast,ci->getOperand(1),"",ci);
-        return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,1));
+    switch(cast<ConstantInt>(CA->getOperand(FormatStartIdx+1))->getZExtValue()){
+    case 'c': {
+      // fprintf(file,"%c",c) -> fputc(c,file)
+      const Type *FILETy = CI->getOperand(1)->getType();
+      Value *C = CastInst::createZExtOrBitCast(CI->getOperand(3), Type::Int32Ty,
+                                               CI->getName()+".int", CI);
+      new CallInst(SLC.get_fputc(FILETy), C, CI->getOperand(1), "", CI);
+      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 1));
+    }
+    case 's': {
+      const Type *FILETy = CI->getOperand(1)->getType();
+      uint64_t LitStrLen, LitStartIdx;
+      ConstantArray *CA = 0;
+      if (GetConstantStringInfo(CI->getOperand(3), CA, LitStrLen, LitStartIdx)){
+        // fprintf(file,"%s",str) -> fwrite(str,strlen(str),1,file)
+        Value *FWriteArgs[] = {
+          CastToCStr(CI->getOperand(3), CI),
+          ConstantInt::get(SLC.getIntPtrType(), LitStrLen),
+          ConstantInt::get(SLC.getIntPtrType(), 1),
+          CI->getOperand(1)
+        };
+        new CallInst(SLC.get_fwrite(FILETy), FWriteArgs, 4, CI->getName(), CI);
+        return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), LitStrLen));
       }
-      default:
+      
+      // If the result of the fprintf call is used, we can't do this.
+      // TODO: we could insert a strlen call.
+      if (!CI->use_empty())
         return false;
+      
+      // fprintf(file,"%s",str) -> fputs(str,file)
+      new CallInst(SLC.get_fputs(FILETy), CastToCStr(CI->getOperand(3), CI),
+                   CI->getOperand(1), CI->getName(), CI);
+      return ReplaceCallWith(CI, 0);
+    }
+    default:
+      return false;
     }
   }
 } FPrintFOptimizer;
@@ -1441,7 +1435,7 @@
     case 's': {
       // sprintf(dest,"%s",str) -> llvm.memcpy(dest, str, strlen(str)+1, 1)
       Value *Len = new CallInst(SLC.get_strlen(),
-                                CastToCStr(ci->getOperand(3), *ci),
+                                CastToCStr(ci->getOperand(3), ci),
                                 ci->getOperand(3)->getName()+".len", ci);
       Value *Len1 = BinaryOperator::createAdd(Len,
                                             ConstantInt::get(Len->getType(), 1),
@@ -1450,8 +1444,8 @@
         Len1 = CastInst::createIntegerCast(Len1, SLC.getIntPtrType(), false,
                                            Len1->getName(), ci);
       Value *args[4] = {
-        CastToCStr(ci->getOperand(1), *ci),
-        CastToCStr(ci->getOperand(3), *ci),
+        CastToCStr(ci->getOperand(1), ci),
+        CastToCStr(ci->getOperand(3), ci),
         Len1,
         ConstantInt::get(Type::Int32Ty,1)
       };
@@ -1946,12 +1940,12 @@
 /// CastToCStr - Return V if it is an sbyte*, otherwise cast it to sbyte*,
 /// inserting the cast before IP, and return the cast.
 /// @brief Cast a value to a "C" string.
-static Value *CastToCStr(Value *V, Instruction &IP) {
+static Value *CastToCStr(Value *V, Instruction *IP) {  // IP: insert point
   assert(isa<PointerType>(V->getType()) && 
          "Can't cast non-pointer type to C string type");
   const Type *SBPTy = PointerType::get(Type::Int8Ty);
   if (V->getType() != SBPTy)
-    return new BitCastInst(V, SBPTy, V->getName(), &IP);
+    return new BitCastInst(V, SBPTy, V->getName(), IP);  // i8* cast, before IP
   return V;
 }
 






More information about the llvm-commits mailing list