[llvm] [AggressiveInstCombine] Inline strcmp/strncmp (PR #89371)
Franklin Zhang via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 30 21:14:25 PDT 2024
================
@@ -922,6 +924,239 @@ static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
return true;
}
+/// Hidden command-line option (initial value 3) that caps how many bytes a
+/// strcmp/strncmp call against a constant string may need to compare and
+/// still be considered for inlining.
+static cl::opt<unsigned> StrNCmpInlineThreshold(
+    "strncmp-inline-threshold", cl::Hidden, cl::init(3),
+    cl::desc("The maximum length of a constant string for a builtin string cmp "
+             "call eligible for inlining. The default value is 3."));
+
+namespace {
+/// Holds everything needed to try inlining a single strcmp/strncmp call.
+///
+/// Usage: construct with the call and its surrounding context, then invoke
+/// optimizeStrNCmp(), which reports whether the call was rewritten.
+class StrNCmpInliner {
+  /// The strcmp/strncmp call under consideration.
+  CallInst *CI;
+  /// Which library function \c CI calls (compared against LibFunc_strncmp).
+  LibFunc Func;
+  /// Iterator owned by the caller; NOTE(review): presumably updated when new
+  /// blocks are inserted -- its use is not visible in this hunk.
+  Function::iterator &BBNext;
+  /// Forwarded to the block-splitting utility when the CFG is rewritten.
+  DomTreeUpdater *DTU;
+  /// Used to query how many bytes of the non-constant string are known to be
+  /// dereferenceable.
+  const DataLayout &DL;
+
+  /// Emit the byte-by-byte comparison of \p LHS against the constant \p RHS
+  /// for the first \p N bytes. \p Switched is true when the constant string
+  /// was the first argument of the original call.
+  bool inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Switched);
+
+public:
+  StrNCmpInliner(CallInst *CI, LibFunc Func, Function::iterator &BBNext,
+                 DomTreeUpdater *DTU, const DataLayout &DL)
+      : CI(CI), Func(Func), BBNext(BBNext), DTU(DTU), DL(DL) {}
+
+  /// Entry point: decide whether \c CI is eligible and, if so, inline it.
+  bool optimizeStrNCmp();
+};
+
+} // namespace
+
+/// Calls to strncmp/strcmp are first normalized to the form
+/// compare(s1, s2, N), meaning "compare the first N bytes of s1 and s2"
+/// (with no special treatment of '\0').
+///
+/// Examples:
+///
+/// \code
+///   strncmp(s, "a", 3)    -> compare(s, "a", 2)
+///   strncmp(s, "abc", 3)  -> compare(s, "abc", 3)
+///   strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
+///   strcmp(s, "a")        -> compare(s, "a", 2)
+///
+///   char s2[] = {'a'}
+///   strncmp(s, s2, 3)     -> compare(s, s2, 3)
+///
+///   char s2[] = {'a', 'b', 'c', 'd'}
+///   strncmp(s, s2, 3)     -> compare(s, s2, 3)
+/// \endcode
+///
+/// Only cases where N and exactly one of s1 and s2 are constant are handled.
+/// Cases where both s1 and s2 are constant are already handled by the
+/// instcombine pass.
+///
+/// The following cases are rejected:
+///  - N > StrNCmpInlineThreshold;
+///  - N < 2, which is already handled by the instcombine pass;
+///  - N larger than the size of the constant string;
+///  - the non-constant string is known to have two or more dereferenceable
+///    bytes, since such calls might be better optimized elsewhere.
+///
+/// \returns the result of inlineCompare() for an eligible call, and false
+/// otherwise.
+bool StrNCmpInliner::optimizeStrNCmp() {
+  if (StrNCmpInlineThreshold < 2)
+    return false;
+
+  Value *Arg0 = CI->getArgOperand(0);
+  Value *Arg1 = CI->getArgOperand(1);
+  // Comparing a pointer with itself should be handled elsewhere.
+  if (Arg0 == Arg1)
+    return false;
+
+  StringRef Const0, Const1;
+  const bool FirstIsConst =
+      getConstantStringInfo(Arg0, Const0, /*TrimAtNul=*/false);
+  const bool SecondIsConst =
+      getConstantStringInfo(Arg1, Const1, /*TrimAtNul=*/false);
+  // Exactly one of the two operands must be a constant string.
+  if (FirstIsConst == SecondIsConst)
+    return false;
+
+  // Note that '\0' and the characters after it are not trimmed.
+  StringRef ConstStr = FirstIsConst ? Const0 : Const1;
+  Value *VarPtr = FirstIsConst ? Arg1 : Arg0;
+
+  // Start from "up to and including the first '\0'", or unbounded when the
+  // constant contains no '\0' at all.
+  uint64_t N = UINT64_MAX;
+  if (size_t NulPos = ConstStr.find('\0'); NulPos != StringRef::npos)
+    N = NulPos + 1;
+
+  if (Func == LibFunc_strncmp) {
+    // The length argument must be a compile-time constant; it further limits
+    // the number of bytes to compare.
+    auto *Limit = dyn_cast<ConstantInt>(CI->getArgOperand(2));
+    if (!Limit)
+      return false;
+    N = std::min(N, Limit->getZExtValue());
+  }
+
+  // N is now the maximum number of bytes that need to be compared.
+  const bool InRange =
+      N >= 2 && N <= ConstStr.size() && N <= StrNCmpInlineThreshold;
+  if (!InRange)
+    return false;
+
+  // Cases where the non-constant string has two or more dereferenceable
+  // bytes might be better optimized elsewhere.
+  bool CanBeNull = false, CanBeFreed = false;
+  const uint64_t KnownBytes =
+      VarPtr->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
+  if (KnownBytes > 1)
+    return false;
+
+  return inlineCompare(VarPtr, ConstStr, N, FirstIsConst);
+}
+
+/// Convert
+///
+/// \code
+/// ret = compare(s1, s2, N)
+/// \endcode
+///
+/// into
+///
+/// \code
+/// ret = (int)s1[0] - (int)s2[0]
+/// if (ret != 0)
+/// goto NE
+/// ...
+/// ret = (int)s1[N-2] - (int)s2[N-2]
+/// if (ret != 0)
+/// goto NE
+/// ret = (int)s1[N-1] - (int)s2[N-1]
+/// NE:
+/// \endcode
+///
+/// CFG before and after the transformation:
+///
+/// (before)
+/// BBCI
+///
+/// (after)
+/// BBBefore -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBCI
+/// | ^
+/// E |
+/// | |
+/// BBSubs[1] (sub,icmp) --NE-----+
+/// ... |
+/// BBSubs[N-1] (sub) ---------+
+///
+bool StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
+ bool Switched) {
+ auto &Ctx = CI->getContext();
+ IRBuilder<> B(Ctx);
+
+ BasicBlock *BBCI = CI->getParent();
+ BasicBlock *BBBefore = splitBlockBefore(BBCI, CI, DTU, nullptr, nullptr,
+ BBCI->getName() + ".before");
+
+ SmallVector<BasicBlock *> BBSubs;
+ for (uint64_t i = 0; i < N; ++i)
+ BBSubs.push_back(BasicBlock::Create(Ctx, "sub", BBCI->getParent(), BBCI));
+ BasicBlock *BBNE = BasicBlock::Create(Ctx, "ne", BBCI->getParent(), BBCI);
+
+ cast<BranchInst>(BBBefore->getTerminator())->setSuccessor(0, BBSubs[0]);
+
+ B.SetInsertPoint(BBNE);
+ PHINode *Phi = B.CreatePHI(CI->getType(), N);
+ B.CreateBr(BBCI);
+
+ Value *Base = LHS;
+ for (uint64_t i = 0; i < N; ++i) {
+ B.SetInsertPoint(BBSubs[i]);
+ Value *VL = B.CreateZExt(
+ B.CreateLoad(B.getInt8Ty(),
+ B.CreateInBoundsGEP(B.getInt8Ty(), Base, B.getInt64(i))),
----------------
FLZ101 wrote:
done.
https://github.com/llvm/llvm-project/pull/89371
More information about the llvm-commits
mailing list