[llvm] [Transforms] LoopIdiomRecognize recognize strlen and wcslen (PR #108985)
Michael Kruse via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 2 08:59:18 PDT 2024
================
@@ -1524,6 +1545,232 @@ static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry,
return nullptr;
}
+/// Recognizes a strlen (or wcslen) idiom by checking for loops that
+/// increment a character pointer until it reaches the null terminator,
+/// after which the base pointer is subtracted from the resulting pointer.
+///
+/// If detected, transforms the loop into a strlen (or wcslen) call and
+/// returns true; otherwise, returns false.
+///
+/// The core idiom we are trying to detect is:
+/// \code
+/// start = str;
+/// do {
+/// str++;
+/// } while(*str != '\0');
+/// \endcode
+///
+/// The transformed output is similar to the following C code:
+/// \code
+/// str = start + strlen(start)
+/// len = str - start
+/// \endcode
+///
+/// The pointer subtraction will later be folded away by InstCombine.
+bool LoopIdiomRecognize::recognizeAndInsertStrLen() {
+ if (DisableLIRPStrlen)
+ return false;
+
+ // Give up if the loop has multiple blocks or multiple backedges.
+ if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
+ return false;
+
+ // It should have a preheader containing nothing but an unconditional branch.
+ auto *Preheader = CurLoop->getLoopPreheader();
+ if (!Preheader || &Preheader->front() != Preheader->getTerminator())
+ return false;
+
+ auto *EntryBI = dyn_cast<BranchInst>(Preheader->getTerminator());
+ if (!EntryBI || EntryBI->isConditional())
+ return false;
+
+ // The loop exit must be conditioned on an icmp with 0.
+ // The icmp operand has to be a load on some SSA reg that increments
+ // by 1 in the loop.
+ BasicBlock *LoopBody = *CurLoop->block_begin();
+ BranchInst *LoopTerm = dyn_cast<BranchInst>(LoopBody->getTerminator());
+ Value *LoopCond = matchCondition(LoopTerm, LoopBody);
+
+ if (!LoopCond)
+ return false;
+
+ auto *LoopLoad = dyn_cast<LoadInst>(LoopCond);
+ if (!LoopLoad || LoopLoad->getPointerAddressSpace() != 0)
+ return false;
+
+ Type *OperandType = LoopLoad->getType();
+ if (!OperandType || !OperandType->isIntegerTy())
+ return false;
+
+ // See if the pointer expression is an AddRec with step 1 ({n,+,1}) on
+ // the loop, indicating strlen calculation.
+ auto *IncPtr = LoopLoad->getPointerOperand();
+ const SCEVAddRecExpr *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IncPtr));
+
+ if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
+ return false;
+
+ const SCEVConstant *Step =
+ dyn_cast<SCEVConstant>(LoadEv->getStepRecurrence(*SE));
+ if (!Step)
+ return false;
+
+ unsigned int StepSize = 0;
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Step->getValue()))
+ StepSize = CI->getZExtValue();
+
+ unsigned OpWidth = OperandType->getIntegerBitWidth();
+ unsigned WcharSize = TLI->getWCharSize(*LoopLoad->getModule());
+ if (OpWidth != StepSize * 8)
+ return false;
+ if (OpWidth != 8 && OpWidth != 16 && OpWidth != 32)
+ return false;
+ if (OpWidth >= 16)
+ if (OpWidth != WcharSize * 8 || DisableLIRPWcslen)
+ return false;
+
+ // Scan every instruction in the loop to ensure there are no side effects.
+ for (auto &I : *LoopBody)
+ if (I.mayHaveSideEffects())
+ return false;
+
+ auto *LoopExitBB = CurLoop->getExitBlock();
+ if (!LoopExitBB)
+ return false;
+
+ // Check that the loop exit block is valid:
+ // It needs to have exactly one LCSSA Phi which is an AddRec.
+ PHINode *LCSSAPhi = nullptr;
+ for (PHINode &PN : LoopExitBB->phis()) {
+ if (!LCSSAPhi && PN.getNumIncomingValues() == 1)
+ LCSSAPhi = &PN;
+ else
+ return false;
+ }
+
+ if (!LCSSAPhi || !SE->isSCEVable(LCSSAPhi->getType()))
+ return false;
+
+ // This matched the pointer version of the idiom
+ if (LCSSAPhi->getIncomingValueForBlock(LoopBody) !=
+ LoopLoad->getPointerOperand())
+ return false;
+
+ const SCEVAddRecExpr *LCSSAEv =
+ dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LCSSAPhi->getIncomingValue(0)));
+
+ if (!LCSSAEv || !dyn_cast<SCEVUnknown>(SE->getPointerBase(LCSSAEv)) ||
+ !LCSSAEv->isAffine())
+ return false;
+
+ // We can now expand the base of the str
+ IRBuilder<> Builder(Preheader->getTerminator());
+
+ auto LoopPhiRange = LoopBody->phis();
+ if (!hasNItems(LoopPhiRange, 1))
+ return false;
+ auto *LoopPhi = &*LoopPhiRange.begin();
+ Value *PreVal = LoopPhi->getIncomingValueForBlock(Preheader);
+ if (!PreVal)
+ return false;
+
+ Value *Expanded = nullptr;
+ Type *ExpandedType = nullptr;
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(LoopLoad->getPointerOperand())) {
+ if (GEP->getPointerOperand() != LoopPhi)
+ return false;
+ GetElementPtrInst *NewGEP = GetElementPtrInst::Create(
+ LoopLoad->getType(), PreVal, SmallVector<Value *, 4>(GEP->indices()),
+ "newgep", Preheader->getTerminator());
+ Expanded = NewGEP;
+ ExpandedType = LoopLoad->getType();
+ } else if (LoopLoad->getPointerOperand() == LoopPhi) {
+ Expanded = PreVal;
+ ExpandedType = LoopLoad->getType();
+ }
+ if (!Expanded)
+ return false;
+
+ // Ensure that the GEP has the correct index if the pointer was modified.
+ // This can happen when the pointer in the user code, outside the loop,
+ // walks past a certain pre-checked index of the string.
+ if (auto *GEP = dyn_cast<GEPOperator>(Expanded)) {
+ if (GEP->getNumOperands() != 2)
+ return false;
+
+ ConstantInt *I0 = dyn_cast<ConstantInt>(GEP->getOperand(1));
+ if (!I0)
+ return false;
+
+ int64_t Index = I0->getSExtValue(); // GEP index
+ auto *SAdd = dyn_cast<SCEVAddExpr>(LoadEv->getStart());
+ if (!SAdd || SAdd->getNumOperands() != 2)
+ return false;
+
+ auto *SAdd0 = dyn_cast<SCEVConstant>(SAdd->getOperand(0));
+ if (!SAdd0)
+ return false;
+
+ ConstantInt *CInt = SAdd0->getValue(); // SCEV index
+ assert(CInt && "Expecting CInt to be valid.");
+ int64_t Offset = CInt->getSExtValue();
+
+ // Update the index based on the Offset
+ assert((Offset * 8) % GEP->getSourceElementType()->getIntegerBitWidth() ==
+ 0 &&
+ "Invalid offset");
+ int64_t NewIndex =
+ (Offset * 8) / GEP->getSourceElementType()->getIntegerBitWidth() -
+ Index;
+ Value *NewIndexVal =
+ ConstantInt::get(GEP->getOperand(1)->getType(), NewIndex);
+ GEP->setOperand(1, NewIndexVal);
+ }
+
+ Value *StrLenFunc = nullptr;
+ switch (OpWidth) {
+ case 8:
+ if (!TLI->has(LibFunc_strlen))
+ return false;
+ StrLenFunc = emitStrLen(Expanded, Builder, *DL, TLI);
+ break;
+ case 16:
+ case 32:
+ if (!TLI->has(LibFunc_wcslen))
+ return false;
+ StrLenFunc = emitWcsLen(Expanded, Builder, *DL, TLI);
+ }
+
+ assert(StrLenFunc && "Failed to emit strlen function.");
+
+ // Replace LCSSA Phi use with new pointer to the null terminator
+ SmallVector<Value *, 4> NewBaseIndex{StrLenFunc};
+ GetElementPtrInst *NewEndPtr = GetElementPtrInst::Create(
+ ExpandedType, Expanded, NewBaseIndex, "end", Preheader->getTerminator());
+ LCSSAPhi->replaceAllUsesWith(NewEndPtr);
+ RecursivelyDeleteDeadPHINode(LCSSAPhi);
+
+ ConstantInt *NewLoopCond = LoopTerm->getSuccessor(0) == LoopBody
+ ? Builder.getFalse()
+ : Builder.getTrue();
+ LoopTerm->setCondition(NewLoopCond);
+
+ deleteDeadInstruction(cast<Instruction>(LoopCond));
+ deleteDeadInstruction(cast<Instruction>(IncPtr));
+ SE->forgetLoop(CurLoop);
+
+ LLVM_DEBUG(dbgs() << " Formed strlen: " << *StrLenFunc << "\n");
+
+ ORE.emit([&]() {
+ return OptimizationRemark(DEBUG_TYPE, "recognizeAndInsertStrLen",
+ CurLoop->getStartLoc(), Preheader)
+ << "Transformed pointer difference into a call to strlen() function";
----------------
Meinersbur wrote:
Not just the pointer difference was transformed; the entire loop was.
Also: it could have been the `wcslen()` function.
https://github.com/llvm/llvm-project/pull/108985
More information about the llvm-commits
mailing list