[llvm] e810d55 - [ValueTracking] Make GetStringLength aware of strdup

via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 12 05:48:04 PDT 2022


Author: serge-sans-paille
Date: 2022-04-12T14:47:29+02:00
New Revision: e810d558093cff40caaa1aff24d289c76c59916d

URL: https://github.com/llvm/llvm-project/commit/e810d558093cff40caaa1aff24d289c76c59916d
DIFF: https://github.com/llvm/llvm-project/commit/e810d558093cff40caaa1aff24d289c76c59916d.diff

LOG: [ValueTracking] Make GetStringLength aware of strdup

During compile-time evaluation of strlen, make it possible to track the
length of strings returned by strdup.

Differential Revision: https://reviews.llvm.org/D123497

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ValueTracking.h
    llvm/lib/Analysis/MemoryBuiltins.cpp
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
    llvm/test/Transforms/InstCombine/strlen-1.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index b97d6285ea5ea..68ba9d1e0e563 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -340,7 +340,9 @@ constexpr unsigned MaxAnalysisRecursionDepth = 6;
 
   /// If we can compute the length of the string pointed to by the specified
   /// pointer, return 'len+1'.  If we can't, return 0.
-  uint64_t GetStringLength(const Value *V, unsigned CharSize = 8);
+  uint64_t GetStringLength(const Value *V,
+                           const TargetLibraryInfo *TLI = nullptr,
+                           unsigned CharSize = 8);
 
   /// This function returns call pointer argument that is considered the same by
   /// aliasing rules. You CAN'T use it to replace one value with another. If

diff  --git a/llvm/lib/Analysis/MemoryBuiltins.cpp b/llvm/lib/Analysis/MemoryBuiltins.cpp
index 151afade1faa5..e6cdf41aa4b77 100644
--- a/llvm/lib/Analysis/MemoryBuiltins.cpp
+++ b/llvm/lib/Analysis/MemoryBuiltins.cpp
@@ -374,7 +374,7 @@ llvm::getAllocSize(const CallBase *CB,
 
   // Handle strdup-like functions separately.
   if (FnData->AllocTy == StrDupLike) {
-    APInt Size(IntTyBits, GetStringLength(Mapper(CB->getArgOperand(0))));
+    APInt Size(IntTyBits, GetStringLength(Mapper(CB->getArgOperand(0)), TLI));
     if (!Size)
       return None;
 

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 75381f5f7e5fb..c4ceb91031d64 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4199,7 +4199,8 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
 /// If we can compute the length of the string pointed to by
 /// the specified pointer, return 'len+1'.  If we can't, return 0.
 static uint64_t GetStringLengthH(const Value *V,
-                                 SmallPtrSetImpl<const PHINode*> &PHIs,
+                                 SmallPtrSetImpl<const PHINode *> &PHIs,
+                                 const TargetLibraryInfo *TLI,
                                  unsigned CharSize) {
   // Look through noop bitcast instructions.
   V = V->stripPointerCasts();
@@ -4213,7 +4214,7 @@ static uint64_t GetStringLengthH(const Value *V,
     // If it was new, see if all the input strings are the same length.
     uint64_t LenSoFar = ~0ULL;
     for (Value *IncValue : PN->incoming_values()) {
-      uint64_t Len = GetStringLengthH(IncValue, PHIs, CharSize);
+      uint64_t Len = GetStringLengthH(IncValue, PHIs, TLI, CharSize);
       if (Len == 0) return 0; // Unknown length -> unknown.
 
       if (Len == ~0ULL) continue;
@@ -4229,9 +4230,9 @@ static uint64_t GetStringLengthH(const Value *V,
 
   // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y)
   if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
-    uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, CharSize);
+    uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, TLI, CharSize);
     if (Len1 == 0) return 0;
-    uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, CharSize);
+    uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, TLI, CharSize);
     if (Len2 == 0) return 0;
     if (Len1 == ~0ULL) return Len2;
     if (Len2 == ~0ULL) return Len1;
@@ -4239,6 +4240,22 @@ static uint64_t GetStringLengthH(const Value *V,
     return Len1;
   }
 
+  if (auto *CB = dyn_cast<CallBase>(V)) {
+    Function *Callee = CB->getCalledFunction();
+    if (!Callee)
+      return 0;
+
+    LibFunc TLIFn;
+    if (!TLI || !TLI->getLibFunc(*CB->getCalledFunction(), TLIFn) ||
+        !TLI->has(TLIFn))
+      return 0;
+
+    if (TLIFn == LibFunc_strdup || TLIFn == LibFunc_dunder_strdup)
+      return GetStringLengthH(CB->getArgOperand(0), PHIs, TLI, CharSize);
+
+    return 0;
+  }
+
   // Otherwise, see if we can read the string.
   ConstantDataArraySlice Slice;
   if (!getConstantDataArrayInfo(V, Slice, CharSize))
@@ -4259,12 +4276,13 @@ static uint64_t GetStringLengthH(const Value *V,
 
 /// If we can compute the length of the string pointed to by
 /// the specified pointer, return 'len+1'.  If we can't, return 0.
-uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) {
+uint64_t llvm::GetStringLength(const Value *V, const TargetLibraryInfo *TLI,
+                               unsigned CharSize) {
   if (!V->getType()->isPointerTy())
     return 0;
 
   SmallPtrSet<const PHINode*, 32> PHIs;
-  uint64_t Len = GetStringLengthH(V, PHIs, CharSize);
+  uint64_t Len = GetStringLengthH(V, PHIs, TLI, CharSize);
   // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return
   // an empty string as a length.
   return Len == ~0ULL ? 1 : Len;

diff  --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 803ea6a2f8d26..2da6355d9eaae 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -212,7 +212,7 @@ Value *LibCallSimplifier::optimizeStrCat(CallInst *CI, IRBuilderBase &B) {
   annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
 
   // See if we can get the length of the input string.
-  uint64_t Len = GetStringLength(Src);
+  uint64_t Len = GetStringLength(Src, TLI);
   if (Len)
     annotateDereferenceableBytes(CI, 1, Len);
   else
@@ -269,7 +269,7 @@ Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilderBase &B) {
   }
 
   // See if we can get the length of the input string.
-  uint64_t SrcLen = GetStringLength(Src);
+  uint64_t SrcLen = GetStringLength(Src, TLI);
   if (SrcLen) {
     annotateDereferenceableBytes(CI, 1, SrcLen);
     --SrcLen; // Unbias length.
@@ -300,7 +300,7 @@ Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilderBase &B) {
   // of the input string and turn this into memchr.
   ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
   if (!CharC) {
-    uint64_t Len = GetStringLength(SrcStr);
+    uint64_t Len = GetStringLength(SrcStr, TLI);
     if (Len)
       annotateDereferenceableBytes(CI, 0, Len);
     else
@@ -387,10 +387,10 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
                         CI->getType());
 
   // strcmp(P, "x") -> memcmp(P, "x", 2)
-  uint64_t Len1 = GetStringLength(Str1P);
+  uint64_t Len1 = GetStringLength(Str1P, TLI);
   if (Len1)
     annotateDereferenceableBytes(CI, 0, Len1);
-  uint64_t Len2 = GetStringLength(Str2P);
+  uint64_t Len2 = GetStringLength(Str2P, TLI);
   if (Len2)
     annotateDereferenceableBytes(CI, 1, Len2);
 
@@ -464,10 +464,10 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
     return B.CreateZExt(B.CreateLoad(B.getInt8Ty(), Str1P, "strcmpload"),
                         CI->getType());
 
-  uint64_t Len1 = GetStringLength(Str1P);
+  uint64_t Len1 = GetStringLength(Str1P, TLI);
   if (Len1)
     annotateDereferenceableBytes(CI, 0, Len1);
-  uint64_t Len2 = GetStringLength(Str2P);
+  uint64_t Len2 = GetStringLength(Str2P, TLI);
   if (Len2)
     annotateDereferenceableBytes(CI, 1, Len2);
 
@@ -496,7 +496,7 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
 Value *LibCallSimplifier::optimizeStrNDup(CallInst *CI, IRBuilderBase &B) {
   Value *Src = CI->getArgOperand(0);
   ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1));
-  uint64_t SrcLen = GetStringLength(Src);
+  uint64_t SrcLen = GetStringLength(Src, TLI);
   if (SrcLen && Size) {
     annotateDereferenceableBytes(CI, 0, SrcLen);
     if (SrcLen <= Size->getZExtValue() + 1)
@@ -513,7 +513,7 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilderBase &B) {
 
   annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
   // See if we can get the length of the input string.
-  uint64_t Len = GetStringLength(Src);
+  uint64_t Len = GetStringLength(Src, TLI);
   if (Len)
     annotateDereferenceableBytes(CI, 1, Len);
   else
@@ -544,7 +544,7 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
   }
 
   // See if we can get the length of the input string.
-  uint64_t Len = GetStringLength(Src);
+  uint64_t Len = GetStringLength(Src, TLI);
   if (Len)
     annotateDereferenceableBytes(CI, 1, Len);
   else
@@ -584,7 +584,7 @@ Value *LibCallSimplifier::optimizeStrNCpy(CallInst *CI, IRBuilderBase &B) {
     return Dst;
 
   // See if we can get the length of the input string.
-  uint64_t SrcLen = GetStringLength(Src);
+  uint64_t SrcLen = GetStringLength(Src, TLI);
   if (SrcLen) {
     annotateDereferenceableBytes(CI, 1, SrcLen);
     --SrcLen; // Unbias length.
@@ -633,7 +633,7 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B,
   Value *Src = CI->getArgOperand(0);
 
   // Constant folding: strlen("xyz") -> 3
-  if (uint64_t Len = GetStringLength(Src, CharSize))
+  if (uint64_t Len = GetStringLength(Src, TLI, CharSize))
     return ConstantInt::get(CI->getType(), Len - 1);
 
   // If s is a constant pointer pointing to a string literal, we can fold
@@ -688,8 +688,8 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B,
 
   // strlen(x?"foo":"bars") --> x ? 3 : 4
   if (SelectInst *SI = dyn_cast<SelectInst>(Src)) {
-    uint64_t LenTrue = GetStringLength(SI->getTrueValue(), CharSize);
-    uint64_t LenFalse = GetStringLength(SI->getFalseValue(), CharSize);
+    uint64_t LenTrue = GetStringLength(SI->getTrueValue(), TLI, CharSize);
+    uint64_t LenFalse = GetStringLength(SI->getFalseValue(), TLI, CharSize);
     if (LenTrue && LenFalse) {
       ORE.emit([&]() {
         return OptimizationRemark("instcombine", "simplify-libcalls", CI)
@@ -2511,7 +2511,7 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
       // sprintf(dest, "%s", str) -> strcpy(dest, str)
       return copyFlags(*CI, emitStrCpy(Dest, CI->getArgOperand(2), B, TLI));
 
-    uint64_t SrcLen = GetStringLength(CI->getArgOperand(2));
+    uint64_t SrcLen = GetStringLength(CI->getArgOperand(2), TLI);
     if (SrcLen) {
       B.CreateMemCpy(
           Dest, Align(1), CI->getArgOperand(2), Align(1),
@@ -2803,7 +2803,7 @@ Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilderBase &B) {
     return nullptr;
 
   // fputs(s,F) --> fwrite(s,strlen(s),1,F)
-  uint64_t Len = GetStringLength(CI->getArgOperand(0));
+  uint64_t Len = GetStringLength(CI->getArgOperand(0), TLI);
   if (!Len)
     return nullptr;
 
@@ -3247,7 +3247,7 @@ FortifiedLibCallSimplifier::isFortifiedCallFoldable(CallInst *CI,
     if (OnlyLowerUnknownSize)
       return false;
     if (StrOp) {
-      uint64_t Len = GetStringLength(CI->getArgOperand(*StrOp));
+      uint64_t Len = GetStringLength(CI->getArgOperand(*StrOp), TLI);
       // If the length is 0 we don't know how long it is and so we can't
       // remove the check.
       if (Len)
@@ -3351,7 +3351,7 @@ Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI,
     return nullptr;
 
   // Maybe we can stil fold __st[rp]cpy_chk to __memcpy_chk.
-  uint64_t Len = GetStringLength(Src);
+  uint64_t Len = GetStringLength(Src, TLI);
   if (Len)
     annotateDereferenceableBytes(CI, 1, Len);
   else

diff  --git a/llvm/test/Transforms/InstCombine/strlen-1.ll b/llvm/test/Transforms/InstCombine/strlen-1.ll
index 4f52b73b960fd..ec6e38822dd21 100644
--- a/llvm/test/Transforms/InstCombine/strlen-1.ll
+++ b/llvm/test/Transforms/InstCombine/strlen-1.ll
@@ -14,6 +14,7 @@ target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f3
 @null_hello_mid = constant [13 x i8] c"hello wor\00ld\00"
 
 declare i32 @strlen(i8*)
+declare noalias i8* @strdup(i8*)
 
 ; Check strlen(string constant) -> integer constant.
 
@@ -280,4 +281,17 @@ define i1 @strlen0_after_write_to_second_byte(i8 *%ptr) {
   ret i1 %cmp
 }
 
+; Check strlen(strdup(string constant)) -> integer constant.
+
+define i32 @test_simplify_strduped_constant() {
+; CHECK-LABEL: @test_simplify_strduped_constant(
+; CHECK-NEXT:    ret i32 5
+;
+  %hello_p = getelementptr [6 x i8], [6 x i8]* @hello, i32 0, i32 0
+  %hello_s = call i8* @strdup(i8* %hello_p)
+  %hello_l = call i32 @strlen(i8* %hello_s)
+  ret i32 %hello_l
+}
+
+
 attributes #0 = { null_pointer_is_valid }


        


More information about the llvm-commits mailing list