[llvm] [Transforms] LoopIdiomRecognize: recognize strlen and wcslen (PR #108985)

Michael Kruse via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 2 08:59:18 PDT 2024


================
@@ -1524,6 +1545,232 @@ static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry,
   return nullptr;
 }
 
+/// Recognizes a strlen idiom by checking for loops that increment
+/// a char pointer and then subtract with the base pointer.
+///
+/// If detected, transforms the relevant code to a strlen function
+/// call, and returns true; otherwise, returns false.
+///
+/// The core idiom we are trying to detect is:
+/// \code
+///     start = str;
+///     do {
+///       str++;
+///     } while(*str != '\0');
+/// \endcode
+///
+/// The transformed output is similar to below c-code:
+/// \code
+///     str = start + strlen(start)
+///     len = str - start
+/// \endcode
+///
+/// Later the pointer subtraction will be folded by InstCombine
+bool LoopIdiomRecognize::recognizeAndInsertStrLen() {
+  if (DisableLIRPStrlen)
+    return false;
+
+  // Give up if the loop has multiple blocks or multiple backedges.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
+    return false;
+
+  // It should have a preheader containing nothing but an unconditional branch.
+  auto *Preheader = CurLoop->getLoopPreheader();
+  if (!Preheader || &Preheader->front() != Preheader->getTerminator())
+    return false;
+
+  auto *EntryBI = dyn_cast<BranchInst>(Preheader->getTerminator());
+  if (!EntryBI || EntryBI->isConditional())
+    return false;
+
+  // The loop exit must be conditioned on an icmp with 0.
+  // The icmp operand has to be a load on some SSA reg that increments
+  // by 1 in the loop.
+  BasicBlock *LoopBody = *CurLoop->block_begin();
+  BranchInst *LoopTerm = dyn_cast<BranchInst>(LoopBody->getTerminator());
+  Value *LoopCond = matchCondition(LoopTerm, LoopBody);
+
+  if (!LoopCond)
+    return false;
+
+  auto *LoopLoad = dyn_cast<LoadInst>(LoopCond);
+  if (!LoopLoad || LoopLoad->getPointerAddressSpace() != 0)
+    return false;
+
+  Type *OperandType = LoopLoad->getType();
+  if (!OperandType || !OperandType->isIntegerTy())
+    return false;
+
+  // See if the pointer expression is an AddRec with step 1 ({n,+,1}) on
+  // the loop, indicating strlen calculation.
+  auto *IncPtr = LoopLoad->getPointerOperand();
+  const SCEVAddRecExpr *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IncPtr));
+
+  if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
+    return false;
+
+  const SCEVConstant *Step =
+      dyn_cast<SCEVConstant>(LoadEv->getStepRecurrence(*SE));
+  if (!Step)
+    return false;
+
+  unsigned int StepSize = 0;
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(Step->getValue()))
+    StepSize = CI->getZExtValue();
+
+  unsigned OpWidth = OperandType->getIntegerBitWidth();
+  unsigned WcharSize = TLI->getWCharSize(*LoopLoad->getModule());
+  if (OpWidth != StepSize * 8)
+    return false;
+  if (OpWidth != 8 && OpWidth != 16 && OpWidth != 32)
+    return false;
+  if (OpWidth >= 16)
+    if (OpWidth != WcharSize * 8 || DisableLIRPWcslen)
+      return false;
+
+  // Scan every instruction in the loop to ensure there are no side effects.
+  for (auto &I : *LoopBody)
+    if (I.mayHaveSideEffects())
+      return false;
+
+  auto *LoopExitBB = CurLoop->getExitBlock();
+  if (!LoopExitBB)
+    return false;
+
+  // Check that the loop exit block is valid:
+  // It needs to have exactly one LCSSA Phi which is an AddRec.
+  PHINode *LCSSAPhi = nullptr;
+  for (PHINode &PN : LoopExitBB->phis()) {
+    if (!LCSSAPhi && PN.getNumIncomingValues() == 1)
+      LCSSAPhi = &PN;
+    else
+      return false;
+  }
+
+  if (!LCSSAPhi || !SE->isSCEVable(LCSSAPhi->getType()))
+    return false;
+
+  // This matched the pointer version of the idiom
+  if (LCSSAPhi->getIncomingValueForBlock(LoopBody) !=
+      LoopLoad->getPointerOperand())
+    return false;
+
+  const SCEVAddRecExpr *LCSSAEv =
+      dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LCSSAPhi->getIncomingValue(0)));
+
+  if (!LCSSAEv || !dyn_cast<SCEVUnknown>(SE->getPointerBase(LCSSAEv)) ||
+      !LCSSAEv->isAffine())
+    return false;
+
+  // We can now expand the base of the str
+  IRBuilder<> Builder(Preheader->getTerminator());
+
+  auto LoopPhiRange = LoopBody->phis();
+  if (!hasNItems(LoopPhiRange, 1))
+    return false;
+  auto *LoopPhi = &*LoopPhiRange.begin();
+  Value *PreVal = LoopPhi->getIncomingValueForBlock(Preheader);
+  if (!PreVal)
+    return false;
+
+  Value *Expanded = nullptr;
+  Type *ExpandedType = nullptr;
+  if (auto *GEP = dyn_cast<GetElementPtrInst>(LoopLoad->getPointerOperand())) {
+    if (GEP->getPointerOperand() != LoopPhi)
+      return false;
+    GetElementPtrInst *NewGEP = GetElementPtrInst::Create(
+        LoopLoad->getType(), PreVal, SmallVector<Value *, 4>(GEP->indices()),
+        "newgep", Preheader->getTerminator());
+    Expanded = NewGEP;
+    ExpandedType = LoopLoad->getType();
+  } else if (LoopLoad->getPointerOperand() == LoopPhi) {
+    Expanded = PreVal;
+    ExpandedType = LoopLoad->getType();
+  }
+  if (!Expanded)
+    return false;
+
+  // Ensure that the GEP has the correct index if the pointer was modified.
+  // This can happen when the pointer in the user code, outside the loop,
+  // walks past a certain pre-checked index of the string.
+  if (auto *GEP = dyn_cast<GEPOperator>(Expanded)) {
+    if (GEP->getNumOperands() != 2)
+      return false;
+
+    ConstantInt *I0 = dyn_cast<ConstantInt>(GEP->getOperand(1));
+    if (!I0)
+      return false;
+
+    int64_t Index = I0->getSExtValue(); // GEP index
+    auto *SAdd = dyn_cast<SCEVAddExpr>(LoadEv->getStart());
+    if (!SAdd || SAdd->getNumOperands() != 2)
+      return false;
+
+    auto *SAdd0 = dyn_cast<SCEVConstant>(SAdd->getOperand(0));
+    if (!SAdd0)
+      return false;
+
+    ConstantInt *CInt = SAdd0->getValue(); // SCEV index
+    assert(CInt && "Expecting CInt to be valid.");
+    int64_t Offset = CInt->getSExtValue();
+
+    // Update the index based on the Offset
+    assert((Offset * 8) % GEP->getSourceElementType()->getIntegerBitWidth() ==
+               0 &&
+           "Invalid offset");
+    int64_t NewIndex =
+        (Offset * 8) / GEP->getSourceElementType()->getIntegerBitWidth() -
+        Index;
+    Value *NewIndexVal =
+        ConstantInt::get(GEP->getOperand(1)->getType(), NewIndex);
+    GEP->setOperand(1, NewIndexVal);
+  }
+
+  Value *StrLenFunc = nullptr;
+  switch (OpWidth) {
+  case 8:
+    if (!TLI->has(LibFunc_strlen))
+      return false;
+    StrLenFunc = emitStrLen(Expanded, Builder, *DL, TLI);
+    break;
+  case 16:
+  case 32:
+    if (!TLI->has(LibFunc_wcslen))
+      return false;
+    StrLenFunc = emitWcsLen(Expanded, Builder, *DL, TLI);
+  }
+
+  assert(StrLenFunc && "Failed to emit strlen function.");
+
+  // Replace LCSSA Phi use with new pointer to the null terminator
+  SmallVector<Value *, 4> NewBaseIndex{StrLenFunc};
+  GetElementPtrInst *NewEndPtr = GetElementPtrInst::Create(
+      ExpandedType, Expanded, NewBaseIndex, "end", Preheader->getTerminator());
----------------
Meinersbur wrote:

```suggestion
  GetElementPtrInst *NewEndPtr = GetElementPtrInst::Create(
      ExpandedType, Expanded, ArrayRef(NewBaseIndex), "end", Preheader->getTerminator());
```
Also, prefer using the `IRBuilder::CreateGEP` function.

https://github.com/llvm/llvm-project/pull/108985


More information about the llvm-commits mailing list