[llvm] [AggressiveInstCombine] Inline strcmp/strncmp (PR #89371)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 19:23:26 PDT 2024


================
@@ -922,13 +914,239 @@ static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
   return true;
 }
 
+/// Upper bound on the number of bytes an inlined strcmp/strncmp may compare.
+/// optimizeStrNCmp() gives up when the normalized length exceeds this value,
+/// and does nothing at all when the value is below 2.
+static cl::opt<unsigned>
+    StrNCmpInlineThreshold("strncmp-inline-threshold", cl::init(3), cl::Hidden,
+                           cl::desc("The maximum length of a constant string "
+                                    "for a builtin string cmp call eligible "
+                                    "for inlining. The default value is 3."));
+
+namespace {
+class StrNCmpInliner { // Inlines one strcmp/strncmp call; see inlineCompare().
+public:
+  StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
+                 const DataLayout &DL)
+      : CI(CI), Func(Func), DTU(DTU), DL(DL) {}
+
+  bool optimizeStrNCmp(); // Entry point; returns true if the call was inlined.
+
+private:
+  bool inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);
+
+  CallInst *CI; // The call being considered; erased by inlineCompare().
+  LibFunc Func; // Checked against LibFunc_strncmp to decide whether to read N.
+  DomTreeUpdater *DTU; // Null-checked in inlineCompare() before applying updates.
+  const DataLayout &DL; // Passed to getPointerDereferenceableBytes().
+};
+
+} // namespace
+
+/// Try to expand the strcmp/strncmp call held in CI into inline byte compares.
+///
+/// The call is first normalized to the form compare(s1, s2, N), meaning
+/// "compare the first N bytes of s1 and s2" with no special treatment of
+/// '\0'.
+///
+/// Examples:
+///
+/// \code
+///   strncmp(s, "a", 3) -> compare(s, "a", 2)
+///   strncmp(s, "abc", 3) -> compare(s, "abc", 3)
+///   strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
+///   strcmp(s, "a") -> compare(s, "a", 2)
+///
+///   char s2[] = {'a'}
+///   strncmp(s, s2, 3) -> compare(s, s2, 3)
+///
+///   char s2[] = {'a', 'b', 'c', 'd'}
+///   strncmp(s, s2, 3) -> compare(s, s2, 3)
+/// \endcode
+///
+/// Only calls where N and exactly one of s1 and s2 are constant are handled.
+/// Calls where both s1 and s2 are constant are already folded by the
+/// instcombine pass.
+///
+/// Calls with N > StrNCmpInlineThreshold are left alone, and so are calls
+/// with N < 2, which the instcombine pass already handles.
+///
+/// \returns true if the call was replaced, false if it was left untouched.
+bool StrNCmpInliner::optimizeStrNCmp() {
+  // N must be at least 2 to be interesting here, so a smaller threshold
+  // rules out every candidate.
+  if (StrNCmpInlineThreshold < 2)
+    return false;
+
+  if (!isOnlyUsedInZeroComparison(CI))
+    return false;
+
+  Value *Str1P = CI->getArgOperand(0);
+  Value *Str2P = CI->getArgOperand(1);
+  // Comparing a pointer with itself should be handled elsewhere.
+  if (Str1P == Str2P)
+    return false;
+
+  // Exactly one operand must be a constant string. The contents are fetched
+  // untrimmed, so '\0' and any characters after it are still present.
+  StringRef Str1, Str2;
+  const bool LHSIsConst = getConstantStringInfo(Str1P, Str1, false);
+  const bool RHSIsConst = getConstantStringInfo(Str2P, Str2, false);
+  if (LHSIsConst == RHSIsConst)
+    return false;
+
+  StringRef ConstStr = LHSIsConst ? Str1 : Str2;
+  Value *VarPtr = LHSIsConst ? Str2P : Str1P;
+
+  // Start from the position just past the first '\0' (or "unbounded" if there
+  // is none), then clamp by the explicit length of strncmp.
+  uint64_t N = UINT64_MAX;
+  size_t NulPos = ConstStr.find('\0');
+  if (NulPos != StringRef::npos)
+    N = NulPos + 1;
+  if (Func == LibFunc_strncmp) {
+    auto *Len = dyn_cast<ConstantInt>(CI->getArgOperand(2));
+    if (!Len)
+      return false;
+    N = std::min(N, Len->getZExtValue());
+  }
+
+  // N is now the maximum number of bytes that have to be compared.
+  const bool InRange =
+      N >= 2 && N <= ConstStr.size() && N <= StrNCmpInlineThreshold;
+  if (!InRange)
+    return false;
+
+  // Cases where the non-constant pointer has two or more dereferenceable
+  // bytes might be better optimized elsewhere.
+  bool CanBeNull = false, CanBeFreed = false;
+  if (VarPtr->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
+    return false;
+
+  return inlineCompare(VarPtr, ConstStr, N, LHSIsConst);
+}
+
+/// Convert
+///
+/// \code
+///   ret = compare(s1, s2, N)
+/// \endcode
+///
+/// into
+///
+/// \code
+///   ret = (int)s1[0] - (int)s2[0]
+///   if (ret != 0)
+///     goto NE
+///   ...
+///   ret = (int)s1[N-2] - (int)s2[N-2]
+///   if (ret != 0)
+///     goto NE
+///   ret = (int)s1[N-1] - (int)s2[N-1]
+///   NE:
+/// \endcode
+///
+/// Every byte is widened as an unsigned value (0..255) on both sides.
+///
+/// CFG before and after the transformation:
+///
+/// (before)
+/// BBCI
+///
+/// (after)
+/// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail
+///                 |                    ^
+///                 E                    |
+///                 |                    |
+///        BBSubs[1] (sub,icmp) --NE-----+
+///                ...                   |
+///        BBSubs[N-1]    (sub) ---------+
+///
+/// \param LHS     Pointer to the non-constant string.
+/// \param RHS     Contents of the constant string; bytes [0, N) are read.
+/// \param N       Number of bytes to compare.
+/// \param Swapped True if the constant string was the first call argument, in
+///                which case each difference is computed as RHS - LHS.
+/// \returns true; the call is always replaced and erased.
+bool StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
+                                   bool Swapped) {
+  auto &Ctx = CI->getContext();
+  IRBuilder<> B(Ctx);
+
+  // Split right before the call so that the call becomes the first
+  // instruction of BBTail; the new blocks are inserted in between.
+  BasicBlock *BBCI = CI->getParent();
+  BasicBlock *BBTail =
+      SplitBlock(BBCI, CI, DTU, nullptr, nullptr, BBCI->getName() + ".tail");
+
+  SmallVector<BasicBlock *> BBSubs;
+  for (uint64_t i = 0; i < N; ++i)
+    BBSubs.push_back(BasicBlock::Create(Ctx, "sub_" + std::to_string(i),
+                                        BBCI->getParent(), BBTail));
+  BasicBlock *BBNE = BasicBlock::Create(Ctx, "ne", BBCI->getParent(), BBTail);
+
+  // Redirect the branch created by SplitBlock into the compare chain.
+  cast<BranchInst>(BBCI->getTerminator())->setSuccessor(0, BBSubs[0]);
+
+  // BBNE merges the per-byte differences and continues to BBTail.
+  B.SetInsertPoint(BBNE);
+  PHINode *Phi = B.CreatePHI(CI->getType(), N);
+  B.CreateBr(BBTail);
+
+  Value *Base = LHS;
+  for (uint64_t i = 0; i < N; ++i) {
+    B.SetInsertPoint(BBSubs[i]);
+    Value *VL =
+        B.CreateZExt(B.CreateLoad(B.getInt8Ty(),
+                                  B.CreateInBoundsPtrAdd(Base, B.getInt64(i))),
+                     CI->getType());
+    // The constant byte must be widened the same way as the loaded byte
+    // (zero-extension). Passing a plain, possibly signed, char would
+    // sign-extend values >= 0x80 and make equal bytes compare as different.
+    Value *VR =
+        ConstantInt::get(CI->getType(), static_cast<unsigned char>(RHS[i]));
+    Value *Sub = Swapped ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR);
+    if (i < N - 1)
+      B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)),
+                     BBNE, BBSubs[i + 1]);
+    else
+      B.CreateBr(BBNE);
+
+    Phi->addIncoming(Sub, BBSubs[i]);
+  }
+
+  CI->replaceAllUsesWith(Phi);
+  CI->eraseFromParent();
+
+  if (DTU) {
+    // Mirror the CFG edits above: BBCI now reaches BBTail only through the
+    // compare chain and BBNE.
+    SmallVector<DominatorTree::UpdateType, 8> Updates;
+    Updates.push_back({DominatorTree::Insert, BBCI, BBSubs[0]});
+    for (uint64_t i = 0; i < N; ++i) {
+      if (i < N - 1)
+        Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
+      Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE});
+    }
+    Updates.push_back({DominatorTree::Insert, BBNE, BBTail});
+    Updates.push_back({DominatorTree::Delete, BBCI, BBTail});
+    DTU->applyUpdates(Updates);
+  }
+  return true;
+}
+
+static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
+                         TargetLibraryInfo &TLI, llvm::AssumptionCache &AC,
----------------
nikic wrote:

```suggestion
                         TargetLibraryInfo &TLI, AssumptionCache &AC,
```

https://github.com/llvm/llvm-project/pull/89371


More information about the llvm-commits mailing list