[llvm] 516915b - [InstCombine] Fold memchr and strchr equality with first argument

Martin Sebor via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 7 14:16:21 PDT 2022


Author: Martin Sebor
Date: 2022-07-07T15:14:23-06:00
New Revision: 516915beb5ee5012a9a8b162fc29664a8c247ec3

URL: https://github.com/llvm/llvm-project/commit/516915beb5ee5012a9a8b162fc29664a8c247ec3
DIFF: https://github.com/llvm/llvm-project/commit/516915beb5ee5012a9a8b162fc29664a8c247ec3.diff

LOG: [InstCombine] Fold memchr and strchr equality with first argument

Enhance memchr and strchr handling to simplify calls to the functions
used in equality expressions with the first argument (or, for strchr
with a nul character, with null) to at most two integer comparisons:

- memchr(A, C, N) == A to N && *A == C for either a dereferenceable
  A or a nonzero N,
- strchr(S, C) == S to *S == C for any S and C, and
- strchr(S, '\0') == 0 to false for any S

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D128939

Added: 
    

Modified: 
    llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
    llvm/test/Transforms/InstCombine/memchr-11.ll
    llvm/test/Transforms/InstCombine/strchr-4.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index f4306bb43dfd..f4093d25f3e8 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -295,31 +295,69 @@ Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilderBase &B) {
   return copyFlags(*CI, emitStrLenMemCpy(Src, Dst, SrcLen, B));
 }
 
+// Helper to transform memchr(S, C, N) == S to N && *S == C and, when
+// NBytes is null, strchr(S, C) == S to *S == C.  The call is replaced by
+// select(Cond, S, null); requires S dereferenceable or N known nonzero.
+static Value* memChrToCharCompare(CallInst *CI, Value *NBytes,
+                                  IRBuilderBase &B, const DataLayout &DL)
+{ // NOTE(review): the DL parameter is not used by this helper.
+  Value *Src = CI->getArgOperand(0);
+  Value *CharVal = CI->getArgOperand(1);
+  // The first byte of S is loaded unconditionally, hence the precondition.
+  // Fold memchr(A, C, N) == A to N && *A == C (compare *A with C as i8).
+  Type *CharTy = B.getInt8Ty();
+  Value *Char0 = B.CreateLoad(CharTy, Src);
+  CharVal = B.CreateTrunc(CharVal, CharTy);
+  Value *Cmp = B.CreateICmpEQ(Char0, CharVal, "char0cmp");
+  // memchr callers pass N as NBytes and also require N != 0; strchr passes null.
+  if (NBytes) {
+    Value *Zero = ConstantInt::get(NBytes->getType(), 0);
+    Value *And = B.CreateICmpNE(NBytes, Zero); // N != 0 (not itself an and).
+    Cmp = B.CreateLogicalAnd(And, Cmp);
+  }
+  // Yield S when the condition holds and null otherwise.
+  Value *NullPtr = Constant::getNullValue(CI->getType());
+  return B.CreateSelect(Cmp, Src, NullPtr);
+}
+
 Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilderBase &B) {
-  Function *Callee = CI->getCalledFunction();
-  FunctionType *FT = Callee->getFunctionType();
   Value *SrcStr = CI->getArgOperand(0);
+  Value *CharVal = CI->getArgOperand(1);
   annotateNonNullNoUndefBasedOnAccess(CI, 0);
 
+  if (isOnlyUsedInEqualityComparison(CI, SrcStr))
+    return memChrToCharCompare(CI, nullptr, B, DL);
+
   // If the second operand is non-constant, see if we can compute the length
   // of the input string and turn this into memchr.
-  ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+  ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal);
   if (!CharC) {
     uint64_t Len = GetStringLength(SrcStr);
     if (Len)
       annotateDereferenceableBytes(CI, 0, Len);
     else
       return nullptr;
+
+    Function *Callee = CI->getCalledFunction();
+    FunctionType *FT = Callee->getFunctionType();
     if (!FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32.
       return nullptr;
 
     return copyFlags(
         *CI,
-        emitMemChr(SrcStr, CI->getArgOperand(1), // include nul.
+        emitMemChr(SrcStr, CharVal, // include nul.
                    ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len), B,
                    DL, TLI));
   }
 
+  if (CharC->isZero()) {
+    Value *NullPtr = Constant::getNullValue(CI->getType());
+    if (isOnlyUsedInEqualityComparison(CI, NullPtr))
+      // Pre-empt the transformation to strlen below and fold
+      // strchr(A, '\0') == null to false.
+      return B.CreateIntToPtr(B.getTrue(), CI->getType());
+  }
+
   // Otherwise, the character is a constant, see if the first argument is
   // a string literal.  If so, we can constant fold.
   StringRef Str;
@@ -1008,8 +1046,12 @@ Value *LibCallSimplifier::optimizeMemRChr(CallInst *CI, IRBuilderBase &B) {
 Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
   Value *SrcStr = CI->getArgOperand(0);
   Value *Size = CI->getArgOperand(2);
-  if (isKnownNonZero(Size, DL))
+
+  if (isKnownNonZero(Size, DL)) {
     annotateNonNullNoUndefBasedOnAccess(CI, 0);
+    if (isOnlyUsedInEqualityComparison(CI, SrcStr))
+      return memChrToCharCompare(CI, Size, B, DL);
+  }
 
   Value *CharVal = CI->getArgOperand(1);
   ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal);
@@ -1099,9 +1141,16 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
     return B.CreateSelect(And, SrcStr, Sel1, "memchr.sel2");
   }
 
-  if (!LenC)
+  if (!LenC) {
+    if (isOnlyUsedInEqualityComparison(CI, SrcStr))
+      // S is dereferenceable so it's safe to load from it and fold
+      //   memchr(S, C, N) == S to N && *S == C for any C and N.
+      // TODO: This is safe even for nonconstant S.
+      return memChrToCharCompare(CI, Size, B, DL);
+
     // From now on we need a constant length and constant array.
     return nullptr;
+  }
 
   // If the char is variable but the input str and length are not we can turn
   // this memchr call into a simple bit field test. Of course this only works

diff  --git a/llvm/test/Transforms/InstCombine/memchr-11.ll b/llvm/test/Transforms/InstCombine/memchr-11.ll
index 7f67357f6449..e434b10d8c58 100644
--- a/llvm/test/Transforms/InstCombine/memchr-11.ll
+++ b/llvm/test/Transforms/InstCombine/memchr-11.ll
@@ -13,9 +13,9 @@ declare i8* @memchr(i8*, i32, i64)
 
 define i1 @fold_memchr_a_c_5_eq_a(i32 %c) {
 ; CHECK-LABEL: @fold_memchr_a_c_5_eq_a(
-; CHECK-NEXT:    [[Q:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0), i32 [[C:%.*]], i64 5)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[Q]], getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0)
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT:    [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], 49
+; CHECK-NEXT:    ret i1 [[CHAR0CMP]]
 ;
   %p = getelementptr [5 x i8], [5 x i8]* @a5, i32 0, i32 0
   %q = call i8* @memchr(i8* %p, i32 %c, i64 5)
@@ -30,9 +30,11 @@ define i1 @fold_memchr_a_c_5_eq_a(i32 %c) {
 
 define i1 @fold_memchr_a_c_n_eq_a(i32 %c, i64 %n) {
 ; CHECK-LABEL: @fold_memchr_a_c_n_eq_a(
-; CHECK-NEXT:    [[Q:%.*]] = call i8* @memchr(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0), i32 [[C:%.*]], i64 [[N:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[Q]], getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0)
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT:    [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], 49
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[N:%.*]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = select i1 [[TMP2]], i1 [[CHAR0CMP]], i1 false
+; CHECK-NEXT:    ret i1 [[TMP3]]
 ;
   %p = getelementptr [5 x i8], [5 x i8]* @a5, i32 0, i32 0
   %q = call i8* @memchr(i8* %p, i32 %c, i64 %n)
@@ -61,9 +63,10 @@ define i1 @call_memchr_api_c_n_eq_a(i64 %i, i32 %c, i64 %n) {
 
 define i1 @fold_memchr_s_c_15_eq_s(i8* %s, i32 %c) {
 ; CHECK-LABEL: @fold_memchr_s_c_15_eq_s(
-; CHECK-NEXT:    [[P:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]], i64 15)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[P]], [[S]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT:    [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[CHAR0CMP]]
 ;
   %p = call i8* @memchr(i8* %s, i32 %c, i64 15)
   %cmp = icmp eq i8* %p, %s
@@ -75,9 +78,10 @@ define i1 @fold_memchr_s_c_15_eq_s(i8* %s, i32 %c) {
 
 define i1 @fold_memchr_s_c_17_neq_s(i8* %s, i32 %c) {
 ; CHECK-LABEL: @fold_memchr_s_c_17_neq_s(
-; CHECK-NEXT:    [[P:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]], i64 17)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8* [[P]], [[S]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT:    [[CHAR0CMP:%.*]] = icmp ne i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[CHAR0CMP]]
 ;
   %p = call i8* @memchr(i8* %s, i32 %c, i64 17)
   %cmp = icmp ne i8* %p, %s
@@ -89,10 +93,10 @@ define i1 @fold_memchr_s_c_17_neq_s(i8* %s, i32 %c) {
 
 define i1 @fold_memchr_s_c_nz_eq_s(i8* %s, i32 %c, i64 %n) {
 ; CHECK-LABEL: @fold_memchr_s_c_nz_eq_s(
-; CHECK-NEXT:    [[NZ:%.*]] = or i64 [[N:%.*]], 1
-; CHECK-NEXT:    [[P:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]], i64 [[NZ]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[P]], [[S]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT:    [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[CHAR0CMP]]
 ;
   %nz = or i64 %n, 1
   %p = call i8* @memchr(i8* %s, i32 %c, i64 %nz)

diff  --git a/llvm/test/Transforms/InstCombine/strchr-4.ll b/llvm/test/Transforms/InstCombine/strchr-4.ll
index aa22d3e36d16..566059b47215 100644
--- a/llvm/test/Transforms/InstCombine/strchr-4.ll
+++ b/llvm/test/Transforms/InstCombine/strchr-4.ll
@@ -11,9 +11,10 @@ declare i8* @strchr(i8*, i32)
 
 define i1 @fold_strchr_s_c_eq_s(i8* %s, i32 %c) {
 ; CHECK-LABEL: @fold_strchr_s_c_eq_s(
-; CHECK-NEXT:    [[P:%.*]] = call i8* @strchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[P]], [[S]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT:    [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[CHAR0CMP]]
 ;
   %p = call i8* @strchr(i8* %s, i32 %c)
   %cmp = icmp eq i8* %p, %s
@@ -25,9 +26,10 @@ define i1 @fold_strchr_s_c_eq_s(i8* %s, i32 %c) {
 
 define i1 @fold_strchr_s_c_neq_s(i8* %s, i32 %c) {
 ; CHECK-LABEL: @fold_strchr_s_c_neq_s(
-; CHECK-NEXT:    [[P:%.*]] = call i8* @strchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8* [[P]], [[S]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT:    [[CHAR0CMP:%.*]] = icmp ne i8 [[TMP1]], [[TMP2]]
+; CHECK-NEXT:    ret i1 [[CHAR0CMP]]
 ;
   %p = call i8* @strchr(i8* %s, i32 %c)
   %cmp = icmp ne i8* %p, %s
@@ -40,8 +42,7 @@ define i1 @fold_strchr_s_c_neq_s(i8* %s, i32 %c) {
 
 define i1 @fold_strchr_s_nul_eqz(i8* %s) {
 ; CHECK-LABEL: @fold_strchr_s_nul_eqz(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[S:%.*]], null
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
   %p = call i8* @strchr(i8* %s, i32 0)
   %cmp = icmp eq i8* %p, null
@@ -53,8 +54,7 @@ define i1 @fold_strchr_s_nul_eqz(i8* %s) {
 
 define i1 @fold_strchr_s_nul_nez(i8* %s) {
 ; CHECK-LABEL: @fold_strchr_s_nul_nez(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8* [[S:%.*]], null
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %p = call i8* @strchr(i8* %s, i32 0)
   %cmp = icmp ne i8* %p, null
@@ -68,9 +68,9 @@ define i1 @fold_strchr_s_nul_nez(i8* %s) {
 
 define i1 @fold_strchr_a_c_eq_a(i32 %c) {
 ; CHECK-LABEL: @fold_strchr_a_c_eq_a(
-; CHECK-NEXT:    [[MEMCHR:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0), i32 [[C:%.*]], i64 6)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8* [[MEMCHR]], getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0)
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT:    [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], 49
+; CHECK-NEXT:    ret i1 [[CHAR0CMP]]
 ;
   %p = getelementptr [5 x i8], [5 x i8]* @a5, i32 0, i32 0
   %q = call i8* @strchr(i8* %p, i32 %c)


        


More information about the llvm-commits mailing list