[llvm-commits] CVS: llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp
    Chris Lattner 
    sabre at nondot.org
       
    Sat Apr  7 14:58:19 PDT 2007
    
    
  
Changes in directory llvm/lib/Transforms/IPO:
SimplifyLibCalls.cpp updated: 1.107 -> 1.108
---
Log message:
Significantly simplify the clients of GetConstantStringInfo, by having it
just return the string itself.
---
Diffs of the changes:  (+153 -234)
 SimplifyLibCalls.cpp |  387 ++++++++++++++++++++-------------------------------
 1 files changed, 153 insertions(+), 234 deletions(-)
Index: llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp
diff -u llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp:1.107 llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp:1.108
--- llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp:1.107	Sat Apr  7 16:17:51 2007
+++ llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp	Sat Apr  7 16:58:02 2007
@@ -391,8 +391,7 @@
 namespace {
 
 // Forward declare utility functions.
-static bool GetConstantStringInfo(Value *V, ConstantArray *&Array,
-                                  uint64_t &Length, uint64_t &StartIdx);
+static bool GetConstantStringInfo(Value *V, std::string &Str);
 static Value *CastToCStr(Value *V, Instruction *IP);
 
 /// This LibCallOptimization will find instances of a call to "exit" that occurs
@@ -465,19 +464,12 @@
 public:
 
   /// @brief Make sure that the "strcat" function has the right prototype
-  virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC){
-    if (f->getReturnType() == PointerType::get(Type::Int8Ty))
-      if (f->arg_size() == 2)
-      {
-        Function::const_arg_iterator AI = f->arg_begin();
-        if (AI++->getType() == PointerType::get(Type::Int8Ty))
-          if (AI->getType() == PointerType::get(Type::Int8Ty))
-          {
-            // Indicate this is a suitable call type.
-            return true;
-          }
-      }
-    return false;
+  virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){
+    const FunctionType *FT = F->getFunctionType();
+    return FT->getNumParams() == 2 &&
+           FT->getReturnType() == PointerType::get(Type::Int8Ty) &&
+           FT->getParamType(0) == FT->getReturnType() &&
+           FT->getParamType(1) == FT->getReturnType();
   }
 
   /// @brief Optimize the strcat library function
@@ -488,18 +480,16 @@
 
     // Extract the initializer (while making numerous checks) from the
     // source operand of the call to strcat.
-    uint64_t SrcLength, StartIdx;
-    ConstantArray *Arr;
-    if (!GetConstantStringInfo(Src, Arr, SrcLength, StartIdx))
+    std::string SrcStr;
+    if (!GetConstantStringInfo(Src, SrcStr))
       return false;
 
     // Handle the simple, do-nothing case
-    if (SrcLength == 0)
+    if (SrcStr.empty())
       return ReplaceCallWith(CI, Dst);
 
     // We need to find the end of the destination string.  That's where the
-    // memory is to be moved to. We just generate a call to strlen (further
-    // optimized in another pass).
+    // memory is to be moved to. We just generate a call to strlen.
     CallInst *DstLen = new CallInst(SLC.get_strlen(), Dst,
                                     Dst->getName()+".len", CI);
 
@@ -512,7 +502,7 @@
     // do the concatenation for us.
     Value *Vals[] = {
       Dst, Src,
-      ConstantInt::get(SLC.getIntPtrType(), SrcLength+1), // copy nul term.
+      ConstantInt::get(SLC.getIntPtrType(), SrcStr.size()+1), // copy nul byte.
       ConstantInt::get(Type::Int32Ty, 1)  // alignment
     };
     new CallInst(SLC.get_memcpy(), Vals, 4, "", CI);
@@ -542,10 +532,8 @@
   /// @brief Perform the strchr optimizations
   virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
     // Check that the first argument to strchr is a constant array of sbyte.
-    // If it is, get the length and data, otherwise return false.
-    uint64_t StrLength, StartIdx;
-    ConstantArray *CA = 0;
-    if (!GetConstantStringInfo(CI->getOperand(1), CA, StrLength, StartIdx))
+    std::string Str;
+    if (!GetConstantStringInfo(CI->getOperand(1), Str))
       return false;
 
     // If the second operand is not constant, just lower this to memchr since we
@@ -555,35 +543,26 @@
       Value *Args[3] = {
         CI->getOperand(1),
         CI->getOperand(2),
-        ConstantInt::get(SLC.getIntPtrType(), StrLength+1)
+        ConstantInt::get(SLC.getIntPtrType(), Str.size()+1)
       };
       return ReplaceCallWith(CI, new CallInst(SLC.get_memchr(), Args, 3,
                                               CI->getName(), CI));
     }
 
+    // strchr can find the nul character.
+    Str += '\0';
+    
     // Get the character we're looking for
-    int64_t CharValue = CSI->getSExtValue();
+    char CharValue = CSI->getSExtValue();
 
-    if (StrLength == 0) {
-      // If the length of the string is zero, and we are searching for zero,
-      // return the input pointer.
-      if (CharValue == 0)
-        return ReplaceCallWith(CI, CI->getOperand(1));
-      // Otherwise, char wasn't found.
-      return ReplaceCallWith(CI, Constant::getNullValue(CI->getType()));
-    }
-    
     // Compute the offset
     uint64_t i = 0;
     while (1) {
-      assert(i <= StrLength && "Didn't find null terminator?");
-      if (ConstantInt *C = dyn_cast<ConstantInt>(CA->getOperand(i+StartIdx))) {
-        // Did we find our match?
-        if (C->getSExtValue() == CharValue)
-          break;
-        if (C->isZero()) // We found the end of the string. strchr returns null.
-          return ReplaceCallWith(CI, Constant::getNullValue(CI->getType()));
-      }
+      if (i == Str.size())    // Didn't find the char.  strchr returns null.
+        return ReplaceCallWith(CI, Constant::getNullValue(CI->getType()));
+      // Did we find our match?
+      if (Str[i] == CharValue)
+        break;
       ++i;
     }
 
@@ -624,34 +603,29 @@
     if (Str1P == Str2P)      // strcmp(x,x)  -> 0
       return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 0));
 
-    uint64_t Str1Len, Str1StartIdx;
-    ConstantArray *A1;
-    bool Str1IsCst = GetConstantStringInfo(Str1P, A1, Str1Len, Str1StartIdx);
-    if (Str1IsCst && Str1Len == 0) {
+    std::string Str1;
+    if (!GetConstantStringInfo(Str1P, Str1))
+      return false;
+    if (Str1.empty()) {
       // strcmp("", x) -> *x
       Value *V = new LoadInst(Str2P, CI->getName()+".load", CI);
       V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI);
       return ReplaceCallWith(CI, V);
     }
 
-    uint64_t Str2Len, Str2StartIdx;
-    ConstantArray* A2;
-    bool Str2IsCst = GetConstantStringInfo(Str2P, A2, Str2Len, Str2StartIdx);
-    if (Str2IsCst && Str2Len == 0) {
+    std::string Str2;
+    if (!GetConstantStringInfo(Str2P, Str2))
+      return false;
+    if (Str2.empty()) {
       // strcmp(x,"") -> *x
       Value *V = new LoadInst(Str1P, CI->getName()+".load", CI);
       V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI);
       return ReplaceCallWith(CI, V);
     }
 
-    if (Str1IsCst && Str2IsCst && A1->isCString() && A2->isCString()) {
-      // strcmp(x, y)  -> cnst  (if both x and y are constant strings)
-      std::string S1 = A1->getAsString();
-      std::string S2 = A2->getAsString();
-      int R = strcmp(S1.c_str()+Str1StartIdx, S2.c_str()+Str2StartIdx);
-      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), R));
-    }
-    return false;
+    // strcmp(x, y)  -> cnst  (if both x and y are constant strings)
+    int R = strcmp(Str1.c_str(), Str2.c_str());
+    return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), R));
   }
 } StrCmpOptimizer;
 
@@ -681,7 +655,7 @@
     // because the call is a no-op.
     Value *Str1P = CI->getOperand(1);
     Value *Str2P = CI->getOperand(2);
-    if (Str1P == Str2P)  // strncmp(x,x)  -> 0
+    if (Str1P == Str2P)  // strncmp(x,x, n)  -> 0
       return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 0));
     
     // Check the length argument, if it is Constant zero then the strings are
@@ -692,40 +666,32 @@
     else
       return false;
     
-    if (Length == 0) {
-      // strncmp(x,y,0)   -> 0
+    if (Length == 0) // strncmp(x,y,0)   -> 0
       return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 0));
-    }
     
-    uint64_t Str1Len, Str1StartIdx;
-    ConstantArray *A1;
-    bool Str1IsCst = GetConstantStringInfo(Str1P, A1, Str1Len, Str1StartIdx);
-    if (Str1IsCst && Str1Len == 0) {
-      // strncmp("", x) -> *x
+    std::string Str1;
+    if (!GetConstantStringInfo(Str1P, Str1))
+      return false;
+    if (Str1.empty()) {
+      // strncmp("", x, n) -> *x
       Value *V = new LoadInst(Str2P, CI->getName()+".load", CI);
       V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI);
       return ReplaceCallWith(CI, V);
     }
     
-    uint64_t Str2Len, Str2StartIdx;
-    ConstantArray* A2;
-    bool Str2IsCst = GetConstantStringInfo(Str2P, A2, Str2Len, Str2StartIdx);
-    if (Str2IsCst && Str2Len == 0) {
-      // strncmp(x,"") -> *x
+    std::string Str2;
+    if (!GetConstantStringInfo(Str2P, Str2))
+      return false;
+    if (Str2.empty()) {
+      // strncmp(x, "", n) -> *x
       Value *V = new LoadInst(Str1P, CI->getName()+".load", CI);
       V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI);
       return ReplaceCallWith(CI, V);
     }
     
-    if (Str1IsCst && Str2IsCst && A1->isCString() &&
-        A2->isCString()) {
-      // strncmp(x, y)  -> cnst  (if both x and y are constant strings)
-      std::string S1 = A1->getAsString();
-      std::string S2 = A2->getAsString();
-      int R = strncmp(S1.c_str()+Str1StartIdx, S2.c_str()+Str2StartIdx, Length);
-      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), R));
-    }
-    return false;
+    // strncmp(x, y, n)  -> cnst  (if both x and y are constant strings)
+    int R = strncmp(Str1.c_str(), Str2.c_str(), Length);
+    return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), R));
   }
 } StrNCmpOptimizer;
 
@@ -764,14 +730,13 @@
     }
     
     // Get the length of the constant string referenced by the Src operand.
-    uint64_t SrcLen, SrcStartIdx;
-    ConstantArray *SrcArr;
-    if (!GetConstantStringInfo(Src, SrcArr, SrcLen, SrcStartIdx))
+    std::string SrcStr;
+    if (!GetConstantStringInfo(Src, SrcStr))
       return false;
-
+    
     // If the constant string's length is zero we can optimize this by just
     // doing a store of 0 at the first byte of the destination
-    if (SrcLen == 0) {
+    if (SrcStr.size() == 0) {
       new StoreInst(ConstantInt::get(Type::Int8Ty, 0), Dst, CI);
       return ReplaceCallWith(CI, Dst);
     }
@@ -779,8 +744,8 @@
     // We have enough information to now generate the memcpy call to
     // do the concatenation for us.
     Value *MemcpyOps[] = {
-      Dst, Src,
-      ConstantInt::get(SLC.getIntPtrType(), SrcLen+1), // length including nul.
+      Dst, Src, // Pass length including nul byte.
+      ConstantInt::get(SLC.getIntPtrType(), SrcStr.size()+1),
       ConstantInt::get(Type::Int32Ty, 1) // alignment
     };
     new CallInst(SLC.get_memcpy(), MemcpyOps, 4, "", CI);
@@ -808,7 +773,7 @@
   /// @brief Perform the strlen optimization
   virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
     // Make sure we're dealing with an sbyte* here.
-    Value *Str = CI->getOperand(1);
+    Value *Src = CI->getOperand(1);
 
     // Does the call to strlen have exactly one use?
     if (CI->hasOneUse()) {
@@ -820,7 +785,7 @@
           if (Cst->getZExtValue() == 0 && Cmp->isEquality()) {
             // strlen(x) != 0 -> *x != 0
             // strlen(x) == 0 -> *x == 0
-            Value *V = new LoadInst(Str, Str->getName()+".first", CI);
+            Value *V = new LoadInst(Src, Src->getName()+".first", CI);
             V = new ICmpInst(Cmp->getPredicate(), V, 
                              ConstantInt::get(Type::Int8Ty, 0),
                              Cmp->getName()+".strlen", CI);
@@ -832,13 +797,12 @@
     }
 
     // Get the length of the constant string operand
-    uint64_t StrLen = 0, StartIdx;
-    ConstantArray *A;
-    if (!GetConstantStringInfo(CI->getOperand(1), A, StrLen, StartIdx))
+    std::string Str;
+    if (!GetConstantStringInfo(Src, Str))
       return false;
-
+      
     // strlen("xyz") -> 3 (for example)
-    return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), StrLen));
+    return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), Str.size()));
   }
 } StrLenOptimizer;
 
@@ -1200,52 +1164,43 @@
     if (CI->getNumOperands() != 3)
       return false;
 
-    // If the result of the printf call is used, none of these optimizations
-    // can be made.
-    if (!CI->use_empty())
-      return false;
-
     // All the optimizations depend on the length of the first argument and the
     // fact that it is a constant string array. Check that now
-    uint64_t FormatLen, FormatIdx;
-    ConstantArray *CA = 0;
-    if (!GetConstantStringInfo(CI->getOperand(1), CA, FormatLen, FormatIdx))
+    std::string FormatStr;
+    if (!GetConstantStringInfo(CI->getOperand(1), FormatStr))
       return false;
 
-    if (FormatLen != 2 && FormatLen != 3)
-      return false;
-
-    // The first character has to be a %
-    if (cast<ConstantInt>(CA->getOperand(FormatIdx))->getZExtValue() != '%')
+    // Only support %c or "%s\n" for now.
+    if (FormatStr.size() < 2 || FormatStr[0] != '%')
       return false;
 
     // Get the second character and switch on its value
-    switch (cast<ConstantInt>(CA->getOperand(FormatIdx+1))->getZExtValue()) {
+    switch (FormatStr[1]) {
     default:  return false;
-    case 's': {
-      if (FormatLen != 3 ||
-          cast<ConstantInt>(CA->getOperand(FormatIdx+2))->getZExtValue() !='\n')
+    case 's':
+      if (FormatStr != "%s\n" ||
+          // TODO: could insert strlen call to compute string length.
+          !CI->use_empty())
         return false;
 
       // printf("%s\n",str) -> puts(str)
       new CallInst(SLC.get_puts(), CastToCStr(CI->getOperand(2), CI),
                    CI->getName(), CI);
       return ReplaceCallWith(CI, 0);
-    }
     case 'c': {
       // printf("%c",c) -> putchar(c)
-      if (FormatLen != 2)
+      if (FormatStr.size() != 2)
         return false;
       
       Value *V = CI->getOperand(2);
       if (!isa<IntegerType>(V->getType()) ||
-          cast<IntegerType>(V->getType())->getBitWidth() < 32)
+          cast<IntegerType>(V->getType())->getBitWidth() > 32)
         return false;
 
-      V = CastInst::createSExtOrBitCast(V, Type::Int32Ty, CI->getName()+".int",
+      V = CastInst::createZExtOrBitCast(V, Type::Int32Ty, CI->getName()+".int",
                                         CI);
       new CallInst(SLC.get_putchar(), V, "", CI);
-      return ReplaceCallWith(CI, 0);
+      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 1));
     }
     }
   }
@@ -1277,20 +1232,14 @@
       return false;
 
     // All the optimizations depend on the format string.
-    uint64_t FormatLen, FormatStartIdx;
-    ConstantArray *CA = 0;
-    if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx))
+    std::string FormatStr;
+    if (!GetConstantStringInfo(CI->getOperand(2), FormatStr))
       return false;
 
-    // IF fthis is just a format string, turn it into fwrite.
+    // If this is just a format string, turn it into fwrite.
     if (CI->getNumOperands() == 3) {
-      if (!CA->isCString()) return false;
-      
-      // Make sure there's no % in the constant array
-      std::string S = CA->getAsString();
-
-      for (unsigned i = FormatStartIdx, e = S.size(); i != e; ++i)
-        if (S[i] == '%')
+      for (unsigned i = 0, e = FormatStr.size(); i != e; ++i)
+        if (FormatStr[i] == '%')
           return false; // we found a format specifier
 
       // fprintf(file,fmt) -> fwrite(fmt,strlen(fmt),file)
@@ -1298,25 +1247,22 @@
 
       Value *FWriteArgs[] = {
         CI->getOperand(2),
-        ConstantInt::get(SLC.getIntPtrType(), FormatLen),
+        ConstantInt::get(SLC.getIntPtrType(), FormatStr.size()),
         ConstantInt::get(SLC.getIntPtrType(), 1),
         CI->getOperand(1)
       };
       new CallInst(SLC.get_fwrite(FILEty), FWriteArgs, 4, CI->getName(), CI);
-      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), FormatLen));
+      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 
+                                                  FormatStr.size()));
     }
     
     // The remaining optimizations require the format string to be length 2:
     // "%s" or "%c".
-    if (FormatLen != 2)
-      return false;
-
-    // The first character has to be a % for us to handle it.
-    if (cast<ConstantInt>(CA->getOperand(FormatStartIdx))->getZExtValue() !='%')
+    if (FormatStr.size() != 2 || FormatStr[0] != '%')
       return false;
 
     // Get the second character and switch on its value
-    switch(cast<ConstantInt>(CA->getOperand(FormatStartIdx+1))->getZExtValue()){
+    switch (FormatStr[1]) {
     case 'c': {
       // fprintf(file,"%c",c) -> fputc(c,file)
       const Type *FILETy = CI->getOperand(1)->getType();
@@ -1327,22 +1273,9 @@
     }
     case 's': {
       const Type *FILETy = CI->getOperand(1)->getType();
-      uint64_t LitStrLen, LitStartIdx;
-      ConstantArray *CA = 0;
-      if (GetConstantStringInfo(CI->getOperand(3), CA, LitStrLen, LitStartIdx)){
-        // fprintf(file,"%s",str) -> fwrite(str,strlen(str),1,file)
-        Value *FWriteArgs[] = {
-          CastToCStr(CI->getOperand(3), CI),
-          ConstantInt::get(SLC.getIntPtrType(), LitStrLen),
-          ConstantInt::get(SLC.getIntPtrType(), 1),
-          CI->getOperand(1)
-        };
-        new CallInst(SLC.get_fwrite(FILETy), FWriteArgs, 4, CI->getName(), CI);
-        return ReplaceCallWith(CI, ConstantInt::get(Type::Int32Ty, LitStrLen));
-      }
       
       // If the result of the fprintf call is used, we can't do this.
-      // TODO: we could insert a strlen call.
+      // TODO: we should insert a strlen call.
       if (!CI->use_empty())
         return false;
       
@@ -1382,37 +1315,34 @@
     if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4)
       return false;
 
-    uint64_t FormatLen, FormatStartIdx;
-    ConstantArray *CA = 0;
-    if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx))
+    std::string FormatStr;
+    if (!GetConstantStringInfo(CI->getOperand(2), FormatStr))
       return false;
     
     if (CI->getNumOperands() == 3) {
-      if (!CA->isCString()) return false;
-      
       // Make sure there's no % in the constant array
-      std::string S = CA->getAsString();
-      for (unsigned i = FormatStartIdx, e = S.size(); i != e; ++i)
-        if (S[i] == '%')
+      for (unsigned i = 0, e = FormatStr.size(); i != e; ++i)
+        if (FormatStr[i] == '%')
           return false; // we found a format specifier
       
       // sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1)
       Value *MemCpyArgs[] = {
         CI->getOperand(1), CI->getOperand(2),
-        ConstantInt::get(SLC.getIntPtrType(), FormatLen+1), // Copy the nul byte
+        ConstantInt::get(SLC.getIntPtrType(), 
+                         FormatStr.size()+1), // Copy the nul byte.
         ConstantInt::get(Type::Int32Ty, 1)
       };
       new CallInst(SLC.get_memcpy(), MemCpyArgs, 4, "", CI);
-      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), FormatLen));
+      return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), 
+                                                  FormatStr.size()));
     }
 
     // The remaining optimizations require the format string to be "%s" or "%c".
-    if (FormatLen != 2 ||
-        cast<ConstantInt>(CA->getOperand(FormatStartIdx))->getZExtValue() !='%')
+    if (FormatStr.size() != 2 || FormatStr[0] != '%')
       return false;
 
     // Get the second character and switch on its value
-    switch (cast<ConstantInt>(CA->getOperand(1))->getZExtValue()) {
+    switch (FormatStr[2]) {
     case 'c': {
       // sprintf(dest,"%c",chr) -> store chr, dest
       Value *V = CastInst::createTruncOrBitCast(CI->getOperand(3),
@@ -1459,10 +1389,10 @@
 /// function. It looks for cases where the result of fputs is not used and the
 /// operation can be reduced to something simpler.
 /// @brief Simplify the puts library function.
-struct VISIBILITY_HIDDEN PutsOptimization : public LibCallOptimization {
+struct VISIBILITY_HIDDEN FPutsOptimization : public LibCallOptimization {
 public:
   /// @brief Default Constructor
-  PutsOptimization() : LibCallOptimization("fputs",
+  FPutsOptimization() : LibCallOptimization("fputs",
       "Number of 'fputs' calls simplified") {}
 
   /// @brief Make sure that the "fputs" function has the right prototype
@@ -1472,49 +1402,45 @@
   }
 
   /// @brief Perform the fputs optimization.
-  virtual bool OptimizeCall(CallInst* ci, SimplifyLibCalls& SLC) {
-    // If the result is used, none of these optimizations work
-    if (!ci->use_empty())
+  virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
+    // If the result is used, none of these optimizations work.
+    if (!CI->use_empty())
       return false;
 
     // All the optimizations depend on the length of the first argument and the
     // fact that it is a constant string array. Check that now
-    uint64_t len, StartIdx;
-    ConstantArray *CA;
-    if (!GetConstantStringInfo(ci->getOperand(1), CA, len, StartIdx))
+    std::string Str;
+    if (!GetConstantStringInfo(CI->getOperand(1), Str))
       return false;
 
-    switch (len) {
-      case 0:
-        // fputs("",F) -> noop
-        break;
-      case 1:
-      {
-        // fputs(s,F)  -> fputc(s[0],F)  (if s is constant and strlen(s) == 1)
-        const Type* FILEptr_type = ci->getOperand(2)->getType();
-        LoadInst* loadi = new LoadInst(ci->getOperand(1),
-          ci->getOperand(1)->getName()+".byte",ci);
-        CastInst* casti = new SExtInst(loadi, Type::Int32Ty, 
-                                       loadi->getName()+".int", ci);
-        new CallInst(SLC.get_fputc(FILEptr_type), casti,
-                     ci->getOperand(2), "", ci);
-        break;
-      }
-      default:
-      {
-        // fputs(s,F)  -> fwrite(s,1,len,F) (if s is constant and strlen(s) > 1)
-        const Type* FILEptr_type = ci->getOperand(2)->getType();
-        Value *parms[4] = {
-          ci->getOperand(1),
-          ConstantInt::get(SLC.getIntPtrType(),len),
-          ConstantInt::get(SLC.getIntPtrType(),1),
-          ci->getOperand(2)
-        };
-        new CallInst(SLC.get_fwrite(FILEptr_type), parms, 4, "", ci);
-        break;
-      }
+    Value *Ptr = CI->getOperand(1);
+    const Type *FILETy = CI->getOperand(2)->getType();
+    // FIXME: Remove these optimizations and fold fwrite with 0/1 length
+    // instead.
+    switch (Str.size()) {
+    case 0:
+      // fputs("",F) -> noop
+      break;
+    case 1: {
+      // fputs(s,F)  -> fputc(s[0],F)  (if s is constant and strlen(s) == 1)
+      Value *Val = new LoadInst(Ptr, Ptr->getName()+".byte", CI);
+      Val = new ZExtInst(Val, Type::Int32Ty, Val->getName()+".int", CI);
+      new CallInst(SLC.get_fputc(FILETy), Val, CI->getOperand(2), "", CI);
+      break;
+    }
+    default: {
+      // fputs(s,F)  -> fwrite(s,1,len,F) (if s is constant and strlen(s) > 1)
+      Value *FWriteParms[4] = {
+        CI->getOperand(1),
+        ConstantInt::get(SLC.getIntPtrType(), Str.size()),
+        ConstantInt::get(SLC.getIntPtrType(), 1),
+        CI->getOperand(2)
+      };
+      new CallInst(SLC.get_fwrite(FILETy), FWriteParms, 4, "", CI);
+      break;
     }
-    return ReplaceCallWith(ci, 0);  // Known to have no uses (see above).
+    }
+    return ReplaceCallWith(CI, 0);  // Known to have no uses (see above).
   }
 } PutsOptimizer;
 
@@ -1828,18 +1754,18 @@
 /// indexed, the \p Length parameter is set to the length of the null-terminated
 /// string pointed to by V, the \p StartIdx value is set to the first
 /// element of the Array that V points to, and true is returned.
-static bool GetConstantStringInfo(Value *V, ConstantArray *&Array,
-                                  uint64_t &Length, uint64_t &StartIdx) {
-  assert(V != 0 && "Invalid args to GetConstantStringInfo");
-  // Initialize results.
-  Length = 0;
-  StartIdx = 0;
-  Array = 0;
+static bool GetConstantStringInfo(Value *V, std::string &Str) {
+  // Look through noop bitcast instructions.
+  if (BitCastInst *BCI = dyn_cast<BitCastInst>(V)) {
+    if (BCI->getType() == BCI->getOperand(0)->getType())
+      return GetConstantStringInfo(BCI->getOperand(0), Str);
+    return false;
+  }
   
-  User *GEP = 0;
   // If the value is not a GEP instruction nor a constant expression with a
   // GEP instruction, then return false because ConstantArray can't occur
   // any other way
+  User *GEP = 0;
   if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(V)) {
     GEP = GEPI;
   } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
@@ -1856,8 +1782,8 @@
 
   // Check to make sure that the first operand of the GEP is an integer and
   // has value 0 so that we are sure we're indexing into the initializer.
-  if (ConstantInt* op1 = dyn_cast<ConstantInt>(GEP->getOperand(1))) {
-    if (!op1->isZero())
+  if (ConstantInt *Idx = dyn_cast<ConstantInt>(GEP->getOperand(1))) {
+    if (!Idx->isZero())
       return false;
   } else
     return false;
@@ -1865,7 +1791,7 @@
   // If the second index isn't a ConstantInt, then this is a variable index
   // into the array.  If this occurs, we can't say anything meaningful about
   // the string.
-  StartIdx = 0;
+  uint64_t StartIdx = 0;
   if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(2)))
     StartIdx = CI->getZExtValue();
   else
@@ -1883,37 +1809,30 @@
   if (isa<ConstantAggregateZero>(GlobalInit)) {
     // This is a degenerate case. The initializer is constant zero so the
     // length of the string must be zero.
-    Length = 0;
+    Str.clear();
     return true;
   }
 
   // Must be a Constant Array
-  Array = dyn_cast<ConstantArray>(GlobalInit);
+  ConstantArray *Array = dyn_cast<ConstantArray>(GlobalInit);
   if (!Array) return false;
 
   // Get the number of elements in the array
   uint64_t NumElts = Array->getType()->getNumElements();
 
-  // Traverse the constant array from start_idx (derived above) which is
+  // Traverse the constant array from StartIdx (derived above) which is
   // the place the GEP refers to in the array.
-  Length = StartIdx;
-  while (1) {
-    if (Length >= NumElts)
-      return false; // The array isn't null terminated.
-    
-    Constant *Elt = Array->getOperand(Length);
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(Elt)) {
-      // Check for the null terminator.
-      if (CI->isZero())
-        break; // we found end of string
-    } else
-      return false; // This array isn't suitable, non-int initializer
-    ++Length;
+  for (unsigned i = StartIdx; i < NumElts; ++i) {
+    Constant *Elt = Array->getOperand(i);
+    ConstantInt *CI = dyn_cast<ConstantInt>(Elt);
+    if (!CI) // This array isn't suitable, non-int initializer.
+      return false;
+    if (CI->isZero())
+      return true; // we found end of string, success!
+    Str += (char)CI->getZExtValue();
   }
   
-  // Subtract out the initial value from the length
-  Length -= StartIdx;
-  return true; // success!
+  return false; // The array isn't null terminated.
 }
 
 /// CastToCStr - Return V if it is an sbyte*, otherwise cast it to sbyte*,
    
    
More information about the llvm-commits
mailing list