[llvm] [NFC][AArch64] Refactor AArch64LoopIdiomTransform in preparation for more idioms (PR #78471)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 17 08:46:15 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: David Sherwood (david-arm)

<details>
<summary>Changes</summary>

In a future patch I intend to add support for other types of C loops that do memory comparisons with early exits, e.g.

  for (unsigned long i = start; i < end; i++) {
    if (p1[i] != p2[i])
      break;
  }

where we compare the loaded values before incrementing the induction variable `i`. Supporting these new loop structures first requires refactoring the pass, so I've created a new MemCompareIdiom class with support routines that can be reused by the new code needed to recognise these loops.

I also modified MemCompareIdiom::generateMemCompare so that it is ready to support arbitrary induction-variable types, as well as new predicates for comparing the loaded memory values.

---

Patch is 20.01 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78471.diff


1 File Affected:

- (modified) llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp (+214-148) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp b/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
index 6dfb2b9df7135d..5b48d5be159f75 100644
--- a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
+++ b/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
@@ -1,3 +1,4 @@
+
 //===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -97,14 +98,6 @@ class AArch64LoopIdiomTransform {
   bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
                       SmallVectorImpl<BasicBlock *> &ExitBlocks);
 
-  bool recognizeByteCompare();
-  Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
-                            GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
-                            Instruction *Index, Value *Start, Value *MaxLen);
-  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
-                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
-                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
-                            BasicBlock *EndBB);
   /// @}
 };
 
@@ -187,6 +180,38 @@ AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
 //
 //===----------------------------------------------------------------------===//
 
+struct MemCompareIdiom {
+private:
+  const TargetTransformInfo *TTI;
+  DominatorTree *DT;
+  LoopInfo *LI;
+  Loop *CurLoop;
+  GetElementPtrInst *GEPA, *GEPB;
+  PHINode *IndPhi;
+  Instruction *Index;
+  bool IncIdx;
+  Value *StartIdx, *MaxLen;
+  BasicBlock *FoundBB, *EndBB;
+  ICmpInst::Predicate CmpPred;
+
+public:
+  MemCompareIdiom(const TargetTransformInfo *TTI, DominatorTree *DT,
+                  LoopInfo *LI, Loop *L)
+      : TTI(TTI), DT(DT), LI(LI), CurLoop(L) {}
+
+  bool recognize();
+  void transform();
+
+private:
+  bool checkIterationCondition(BasicBlock *CheckItBB, BasicBlock *&LoopBB);
+  bool checkLoadCompareCondition(BasicBlock *LoadCmpBB, Value *&LoadA,
+                                 Value *&LoadB);
+  bool areValidLoads(Value *BaseIdx, Value *LoadA, Value *LoadB);
+  bool checkEndAndFoundBlockPhis(Value *IndVal);
+  bool recognizePostIncMemCompare();
+  Value *generateMemCompare(IRBuilder<> &Builder, DomTreeUpdater &DTU);
+};
+
 bool AArch64LoopIdiomTransform::run(Loop *L) {
   CurLoop = L;
 
@@ -202,32 +227,135 @@ bool AArch64LoopIdiomTransform::run(Loop *L) {
                     << CurLoop->getHeader()->getParent()->getName()
                     << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
 
-  return recognizeByteCompare();
+  MemCompareIdiom BCI(TTI, DT, LI, L);
+  if (BCI.recognize()) {
+    BCI.transform();
+    return true;
+  }
+
+  return false;
 }
 
-bool AArch64LoopIdiomTransform::recognizeByteCompare() {
-  // Currently the transformation only works on scalable vector types, although
-  // there is no fundamental reason why it cannot be made to work for fixed
-  // width too.
+bool MemCompareIdiom::areValidLoads(Value *BaseIdx, Value *LoadA,
+                                    Value *LoadB) {
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+    return false;
 
-  // We also need to know the minimum page size for the target in order to
-  // generate runtime memory checks to ensure the vector version won't fault.
-  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
-      DisableByteCmp)
+  if (!cast<LoadInst>(LoadA)->isSimple() || !cast<LoadInst>(LoadB)->isSimple())
     return false;
 
-  BasicBlock *Header = CurLoop->getHeader();
+  GEPA = dyn_cast<GetElementPtrInst>(A);
+  GEPB = dyn_cast<GetElementPtrInst>(B);
 
-  // In AArch64LoopIdiomTransform::run we have already checked that the loop
-  // has a preheader so we can assume it's in a canonical form.
-  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+  if (!GEPA || !GEPB)
     return false;
 
-  PHINode *PN = dyn_cast<PHINode>(&Header->front());
-  if (!PN || PN->getNumIncomingValues() != 2)
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Check we are loading i8 values from two loop invariant pointers
+  Type *MemType = GEPA->getResultElementType();
+  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
+      !MemType->isIntegerTy(8) || GEPB->getResultElementType() != MemType ||
+      cast<LoadInst>(LoadA)->getType() != MemType ||
+      cast<LoadInst>(LoadB)->getType() != MemType || PtrA == PtrB)
+    return false;
+
+  // Check that the index to the GEPs is the index we found earlier
+  if (GEPA->getNumIndices() != 1 || GEPB->getNumIndices() != 1)
+    return false;
+
+  Value *IdxA = GEPA->getOperand(1);
+  Value *IdxB = GEPB->getOperand(1);
+
+  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(BaseIdx))))
     return false;
 
+  return true;
+}
+
+static bool checkBlockSizes(Loop *CurLoop, unsigned Block1Size,
+                            unsigned Block2Size) {
   auto LoopBlocks = CurLoop->getBlocks();
+
+  auto BB1 = LoopBlocks[0]->instructionsWithoutDebug();
+  if (std::distance(BB1.begin(), BB1.end()) > Block1Size)
+    return false;
+
+  auto BB2 = LoopBlocks[1]->instructionsWithoutDebug();
+  if (std::distance(BB2.begin(), BB2.end()) > Block2Size)
+    return false;
+
+  return true;
+}
+
+bool MemCompareIdiom::checkIterationCondition(BasicBlock *CheckItBB,
+                                              BasicBlock *&LoopBB) {
+  ICmpInst::Predicate Pred;
+  if (!match(CheckItBB->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
+                  m_BasicBlock(EndBB), m_BasicBlock(LoopBB))) ||
+      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(LoopBB))
+    return false;
+
+  return true;
+}
+
+bool MemCompareIdiom::checkLoadCompareCondition(BasicBlock *LoadCmpBB,
+                                                Value *&LoadA, Value *&LoadB) {
+  BasicBlock *TrueBB;
+  ICmpInst::Predicate Pred;
+  // TODO: Support other predicates.
+  CmpPred = ICmpInst::Predicate::ICMP_EQ;
+  if (!match(LoadCmpBB->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Value(LoadA), m_Value(LoadB)),
+                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
+      Pred != CmpPred || TrueBB != CurLoop->getHeader())
+    return false;
+
+  return true;
+}
+
+bool MemCompareIdiom::checkEndAndFoundBlockPhis(Value *IndVal) {
+  // Ensure that when the Found and End blocks are identical the PHIs have the
+  // supported format. We don't currently allow cases like this:
+  // while.cond:
+  //   ...
+  //   br i1 %cmp.not, label %while.end, label %while.body
+  //
+  // while.body:
+  //   ...
+  //   br i1 %cmp.not2, label %while.cond, label %while.end
+  //
+  // while.end:
+  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
+  //
+  // Where the incoming values for %final_ptr are unique and from each of the
+  // loop blocks, but not actually defined in the loop. This requires extra
+  // work setting up the byte.compare block, i.e. by introducing a select to
+  // choose the correct value.
+  // TODO: We could add support for this in future.
+  if (FoundBB == EndBB) {
+    auto LoopBlocks = CurLoop->getBlocks();
+    for (PHINode &EndPN : EndBB->phis()) {
+      Value *V1 = EndPN.getIncomingValueForBlock(LoopBlocks[0]);
+      Value *V2 = EndPN.getIncomingValueForBlock(LoopBlocks[1]);
+
+      // The value of the index when leaving the while.cond block is always the
+      // same as the end value (MaxLen) so we permit either. Otherwise for any
+      // other value defined outside the loop we only allow values that are the
+      // same as the exit value for while.body.
+      if (V1 != V2 &&
+          ((V1 != IndVal && V1 != MaxLen) || (V2 != IndVal)))
+        return false;
+    }
+  }
+
+  return true;
+}
+
+bool MemCompareIdiom::recognizePostIncMemCompare() {
   // The first block in the loop should contain only 4 instructions, e.g.
   //
   //  while.cond:
@@ -236,10 +364,6 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not = icmp eq i32 %inc, %n
   //   br i1 %cmp.not, label %while.end, label %while.body
   //
-  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
-  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
-    return false;
-
   // The second block should contain 7 instructions, e.g.
   //
   // while.body:
@@ -251,145 +375,90 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
   //   br i1 %cmp.not.ld, label %while.cond, label %while.end
   //
-  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
-  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+  if (!checkBlockSizes(CurLoop, 4, 7))
     return false;
 
-  // The incoming value to the PHI node from the loop should be an add of 1.
-  Value *StartIdx = nullptr;
-  Instruction *Index = nullptr;
-  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
-    StartIdx = PN->getIncomingValue(0);
-    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
-  } else {
-    StartIdx = PN->getIncomingValue(1);
-    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
-  }
-
-  // Limit to 32-bit types for now
-  if (!Index || !Index->getType()->isIntegerTy(32) ||
-      !match(Index, m_c_Add(m_Specific(PN), m_One())))
-    return false;
-
-  // If we match the pattern, PN and Index will be replaced with the result of
-  // the cttz.elts intrinsic. If any other instructions are used outside of
-  // the loop, we cannot replace it.
-  for (BasicBlock *BB : LoopBlocks)
-    for (Instruction &I : *BB)
-      if (&I != PN && &I != Index)
-        for (User *U : I.users())
-          if (!CurLoop->contains(cast<Instruction>(U)))
-            return false;
-
   // Match the branch instruction for the header
-  ICmpInst::Predicate Pred;
-  Value *MaxLen;
-  BasicBlock *EndBB, *WhileBB;
-  if (!match(Header->getTerminator(),
-             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
-                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
-      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
+  BasicBlock *Header = CurLoop->getHeader();
+  BasicBlock *WhileBodyBB;
+  if (!checkIterationCondition(Header, WhileBodyBB))
     return false;
 
   // WhileBB should contain the pattern of load & compare instructions. Match
   // the pattern and find the GEP instructions used by the loads.
-  ICmpInst::Predicate WhilePred;
-  BasicBlock *FoundBB;
-  BasicBlock *TrueBB;
   Value *LoadA, *LoadB;
-  if (!match(WhileBB->getTerminator(),
-             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
-                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
-      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
+  if (!checkLoadCompareCondition(WhileBodyBB, LoadA, LoadB))
     return false;
 
-  Value *A, *B;
-  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+  if (!areValidLoads(Index, LoadA, LoadB))
     return false;
 
-  LoadInst *LoadAI = cast<LoadInst>(LoadA);
-  LoadInst *LoadBI = cast<LoadInst>(LoadB);
-  if (!LoadAI->isSimple() || !LoadBI->isSimple())
+  if (!checkEndAndFoundBlockPhis(Index))
     return false;
 
-  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
-  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
+  return true;
+}
 
-  if (!GEPA || !GEPB)
+bool MemCompareIdiom::recognize() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // width too.
+
+  // We also need to know the minimum page size for the target in order to
+  // generate runtime memory checks to ensure the vector version won't fault.
+  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+      DisableByteCmp)
     return false;
 
-  Value *PtrA = GEPA->getPointerOperand();
-  Value *PtrB = GEPB->getPointerOperand();
+  BasicBlock *Header = CurLoop->getHeader();
 
-  // Check we are loading i8 values from two loop invariant pointers
-  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
-      !GEPA->getResultElementType()->isIntegerTy(8) ||
-      !GEPB->getResultElementType()->isIntegerTy(8) ||
-      !LoadAI->getType()->isIntegerTy(8) ||
-      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
+  // In AArch64LoopIdiomTransform::run we have already checked that the loop
+  // has a preheader so we can assume it's in a canonical form.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
     return false;
 
-  // Check that the index to the GEPs is the index we found earlier
-  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
+  // We only expect one PHI node for the index.
+  IndPhi = dyn_cast<PHINode>(&Header->front());
+  if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
     return false;
 
-  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
-  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
-  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
-    return false;
+  // The incoming value to the PHI node from the loop should be an add of 1.
+  StartIdx = nullptr;
+  Index = nullptr;
+  if (!CurLoop->contains(IndPhi->getIncomingBlock(0))) {
+    StartIdx = IndPhi->getIncomingValue(0);
+    Index = dyn_cast<Instruction>(IndPhi->getIncomingValue(1));
+  } else {
+    StartIdx = IndPhi->getIncomingValue(1);
+    Index = dyn_cast<Instruction>(IndPhi->getIncomingValue(0));
+  }
 
-  // We only ever expect the pre-incremented index value to be used inside the
-  // loop.
-  if (!PN->hasOneUse())
+  if (!Index || !Index->getType()->isIntegerTy(32) ||
+      !match(Index, m_c_Add(m_Specific(IndPhi), m_One())))
     return false;
 
-  // Ensure that when the Found and End blocks are identical the PHIs have the
-  // supported format. We don't currently allow cases like this:
-  // while.cond:
-  //   ...
-  //   br i1 %cmp.not, label %while.end, label %while.body
-  //
-  // while.body:
-  //   ...
-  //   br i1 %cmp.not2, label %while.cond, label %while.end
-  //
-  // while.end:
-  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
-  //
-  // Where the incoming values for %final_ptr are unique and from each of the
-  // loop blocks, but not actually defined in the loop. This requires extra
-  // work setting up the byte.compare block, i.e. by introducing a select to
-  // choose the correct value.
-  // TODO: We could add support for this in future.
-  if (FoundBB == EndBB) {
-    for (PHINode &EndPN : EndBB->phis()) {
-      Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
-      Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);
+  // If we match the pattern, IndPhi and Index will be replaced with the result
+  // of the mismatch. If any other instructions are used outside of the loop, we
+  // cannot replace it.
+  for (BasicBlock *BB : CurLoop->getBlocks())
+    for (Instruction &I : *BB)
+      if (&I != IndPhi && &I != Index)
+        for (User *U : I.users())
+          if (!CurLoop->contains(cast<Instruction>(U)))
+            return false;
 
-      // The value of the index when leaving the while.cond block is always the
-      // same as the end value (MaxLen) so we permit either. The value when
-      // leaving the while.body block should only be the index. Otherwise for
-      // any other values we only allow ones that are same for both blocks.
-      if (WhileCondVal != WhileBodyVal &&
-          ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
-           (WhileBodyVal != Index)))
-        return false;
-    }
-  }
+  if (recognizePostIncMemCompare())
+    IncIdx = true;
+  else
+    return false;
 
   LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
                     << *(EndBB->getParent()) << "\n\n");
-
-  // The index is incremented before the GEP/Load pair so we need to
-  // add 1 to the start value.
-  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
-                       FoundBB, EndBB);
   return true;
 }
 
-Value *AArch64LoopIdiomTransform::expandFindMismatch(
-    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
-    GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
+Value *MemCompareIdiom::generateMemCompare(IRBuilder<> &Builder,
+                                           DomTreeUpdater &DTU) {
   Value *PtrA = GEPA->getPointerOperand();
   Value *PtrB = GEPB->getPointerOperand();
 
@@ -398,7 +467,7 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
   BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
   LLVMContext &Ctx = PHBranch->getContext();
   Type *LoadType = Type::getInt8Ty(Ctx);
-  Type *ResType = Builder.getInt32Ty();
+  Type *ResType = StartIdx->getType();
 
   // Split block in the original loop preheader.
   BasicBlock *EndBlock =
@@ -479,11 +548,11 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
 
   // Check the zero-extended iteration count > 0
   Builder.SetInsertPoint(MinItCheckBlock);
-  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
+  Value *ExtStart = Builder.CreateZExt(StartIdx, I64Type);
   Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
   // This check doesn't really cost us very much.
 
-  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
+  Value *LimitCheck = Builder.CreateICmpULE(StartIdx, MaxLen);
   BranchInst *MinItCheckBr =
       BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
   MinItCheckBr->setMetadata(
@@ -592,7 +661,8 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
   Value *SVERhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVERhsGep, Align(1),
                                                LoopPred, Passthru);
 
-  Value *SVEMatchCmp = Builder.CreateICmpNE(SVELhsLoad, SVERhsLoad);
+  Value *SVEMatchCmp = Builder.CreateICmp(
+      ICmpInst::getInversePredicate(CmpPred), SVELhsLoad, SVERhsLoad);
   SVEMatchCmp = Builder.CreateSelect(LoopPred, SVEMatchCmp, PFalse);
   Value *SVEMatchHasActiveLanes = Builder.CreateOrReduce(SVEMatchCmp);
   BranchInst *SVEEarlyExit = BranchInst::Create(
@@ -658,7 +728,7 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
 
   Builder.SetInsertPoint(LoopStartBlock);
   PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
-  IndexPhi->addIncoming(Start, LoopPreHeaderBlock);
+  IndexPhi->addIncoming(StartIdx, LoopPreHeaderBlock);
 
   // Otherwise compare the values
   // Load bytes from each array and compare them.
@@ -674,7 +744,7 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
     cast<GetElementPtrInst>(RhsGep)->setIsInBounds(true);
   Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);
 
-  Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
+  Value *MatchCmp = Builder.CreateICmp(CmpPred, LhsLoad, RhsLoad);
   // If we have a mismatch then exit the loop ...
   BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
   Builder.Insert(MatchCmpBr);
@@ -723,11 +793,7 @@ Value *AArch64LoopIdiomTransform::expandFindMismatch(
   return FinalRes;
 }
 
-void AArch64LoopIdiomTransform::transformByteCompare(
-    GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, PHINode *IndPhi,
-    Value *MaxLen, Instruction *Index, Value *Start, bool IncIdx,
-    BasicBlock *FoundBB, BasicBlock *EndBB) {
-
+void MemCompareIdiom::transform() {
   // Insert the byte compare code at the end of the preheader block
   BasicBlock *Preheader = CurLoop->getLoopPreheader();
   BasicBlock *Header = CurLoop->getHeader();
@@ -738,10 +804,10 @@ void AArch64LoopIdiomTransform::transformByteCompare(
 
   // Increment the pointer if this was done before the loads in the loop.
   if (IncIdx)
-    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));
+    StartIdx =
+        Builder.CreateAdd(StartIdx, ConstantInt::get(StartIdx->getType(), 1));
 
-  Value *ByteCmpRes =
-      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);
+  Value *ByteCmpRes = generateMemCompare(Builder, DTU);
 
   // Replaces uses of index & induction Phi with intrinsic (we already
   // checked that the the first instruction of Header is t...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/78471


More information about the llvm-commits mailing list