[llvm] [NFC][AArch64] Refactor AArch64LoopIdiomTransform in preparation for more idioms (PR #78471)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 23 05:55:28 PST 2024


================
@@ -251,145 +375,90 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
   //   br i1 %cmp.not.ld, label %while.cond, label %while.end
   //
-  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
-  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+  if (!checkBlockSizes(CurLoop, 4, 7))
     return false;
 
-  // The incoming value to the PHI node from the loop should be an add of 1.
-  Value *StartIdx = nullptr;
-  Instruction *Index = nullptr;
-  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
-    StartIdx = PN->getIncomingValue(0);
-    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
-  } else {
-    StartIdx = PN->getIncomingValue(1);
-    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
-  }
-
-  // Limit to 32-bit types for now
-  if (!Index || !Index->getType()->isIntegerTy(32) ||
-      !match(Index, m_c_Add(m_Specific(PN), m_One())))
-    return false;
-
-  // If we match the pattern, PN and Index will be replaced with the result of
-  // the cttz.elts intrinsic. If any other instructions are used outside of
-  // the loop, we cannot replace it.
-  for (BasicBlock *BB : LoopBlocks)
-    for (Instruction &I : *BB)
-      if (&I != PN && &I != Index)
-        for (User *U : I.users())
-          if (!CurLoop->contains(cast<Instruction>(U)))
-            return false;
-
   // Match the branch instruction for the header
-  ICmpInst::Predicate Pred;
-  Value *MaxLen;
-  BasicBlock *EndBB, *WhileBB;
-  if (!match(Header->getTerminator(),
-             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
-                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
-      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
+  BasicBlock *Header = CurLoop->getHeader();
+  BasicBlock *WhileBodyBB;
+  if (!checkIterationCondition(Header, WhileBodyBB))
     return false;
 
   // WhileBB should contain the pattern of load & compare instructions. Match
   // the pattern and find the GEP instructions used by the loads.
-  ICmpInst::Predicate WhilePred;
-  BasicBlock *FoundBB;
-  BasicBlock *TrueBB;
   Value *LoadA, *LoadB;
-  if (!match(WhileBB->getTerminator(),
-             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
-                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
-      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
+  if (!checkLoadCompareCondition(WhileBodyBB, LoadA, LoadB))
     return false;
 
-  Value *A, *B;
-  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+  if (!areValidLoads(Index, LoadA, LoadB))
     return false;
 
-  LoadInst *LoadAI = cast<LoadInst>(LoadA);
-  LoadInst *LoadBI = cast<LoadInst>(LoadB);
-  if (!LoadAI->isSimple() || !LoadBI->isSimple())
+  if (!checkEndAndFoundBlockPhis(Index))
     return false;
 
-  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
-  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
+  return true;
+}
 
-  if (!GEPA || !GEPB)
+bool MemCompareIdiom::recognize() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // width too.
+
+  // We also need to know the minimum page size for the target in order to
+  // generate runtime memory checks to ensure the vector version won't fault.
+  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+      DisableByteCmp)
     return false;
 
-  Value *PtrA = GEPA->getPointerOperand();
-  Value *PtrB = GEPB->getPointerOperand();
+  BasicBlock *Header = CurLoop->getHeader();
 
-  // Check we are loading i8 values from two loop invariant pointers
-  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
-      !GEPA->getResultElementType()->isIntegerTy(8) ||
-      !GEPB->getResultElementType()->isIntegerTy(8) ||
-      !LoadAI->getType()->isIntegerTy(8) ||
-      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
+  // In AArch64LoopIdiomTransform::run we have already checked that the loop
+  // has a preheader so we can assume it's in a canonical form.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
     return false;
 
-  // Check that the index to the GEPs is the index we found earlier
-  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
+  // We only expect one PHI node for the index.
----------------
david-arm wrote:

Possibly? However, `getInductionVariable` calls `getLatchCmpInst`, so I think we might have the same issue as mentioned above.

https://github.com/llvm/llvm-project/pull/78471


More information about the llvm-commits mailing list