[llvm] [NFC][AArch64] Refactor AArch64LoopIdiomTransform in preparation for more idioms (PR #78471)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 23 07:54:43 PST 2024


================
@@ -251,145 +375,90 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
   //   br i1 %cmp.not.ld, label %while.cond, label %while.end
   //
-  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
-  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+  if (!checkBlockSizes(CurLoop, 4, 7))
     return false;
 
-  // The incoming value to the PHI node from the loop should be an add of 1.
-  Value *StartIdx = nullptr;
-  Instruction *Index = nullptr;
-  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
-    StartIdx = PN->getIncomingValue(0);
-    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
-  } else {
-    StartIdx = PN->getIncomingValue(1);
-    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
-  }
-
-  // Limit to 32-bit types for now
-  if (!Index || !Index->getType()->isIntegerTy(32) ||
-      !match(Index, m_c_Add(m_Specific(PN), m_One())))
-    return false;
-
-  // If we match the pattern, PN and Index will be replaced with the result of
-  // the cttz.elts intrinsic. If any other instructions are used outside of
-  // the loop, we cannot replace it.
-  for (BasicBlock *BB : LoopBlocks)
-    for (Instruction &I : *BB)
-      if (&I != PN && &I != Index)
-        for (User *U : I.users())
-          if (!CurLoop->contains(cast<Instruction>(U)))
-            return false;
-
   // Match the branch instruction for the header
-  ICmpInst::Predicate Pred;
-  Value *MaxLen;
-  BasicBlock *EndBB, *WhileBB;
-  if (!match(Header->getTerminator(),
-             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
-                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
-      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
+  BasicBlock *Header = CurLoop->getHeader();
+  BasicBlock *WhileBodyBB;
+  if (!checkIterationCondition(Header, WhileBodyBB))
     return false;
 
   // WhileBB should contain the pattern of load & compare instructions. Match
   // the pattern and find the GEP instructions used by the loads.
-  ICmpInst::Predicate WhilePred;
-  BasicBlock *FoundBB;
-  BasicBlock *TrueBB;
   Value *LoadA, *LoadB;
-  if (!match(WhileBB->getTerminator(),
-             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
-                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
-      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
+  if (!checkLoadCompareCondition(WhileBodyBB, LoadA, LoadB))
     return false;
 
-  Value *A, *B;
-  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+  if (!areValidLoads(Index, LoadA, LoadB))
     return false;
 
-  LoadInst *LoadAI = cast<LoadInst>(LoadA);
-  LoadInst *LoadBI = cast<LoadInst>(LoadB);
-  if (!LoadAI->isSimple() || !LoadBI->isSimple())
+  if (!checkEndAndFoundBlockPhis(Index))
     return false;
 
-  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
-  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
+  return true;
+}
 
-  if (!GEPA || !GEPB)
+bool MemCompareIdiom::recognize() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // width too.
+
+  // We also need to know the minimum page size for the target in order to
+  // generate runtime memory checks to ensure the vector version won't fault.
+  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+      DisableByteCmp)
     return false;
 
-  Value *PtrA = GEPA->getPointerOperand();
-  Value *PtrB = GEPB->getPointerOperand();
+  BasicBlock *Header = CurLoop->getHeader();
 
-  // Check we are loading i8 values from two loop invariant pointers
-  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
-      !GEPA->getResultElementType()->isIntegerTy(8) ||
-      !GEPB->getResultElementType()->isIntegerTy(8) ||
-      !LoadAI->getType()->isIntegerTy(8) ||
-      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
+  // In AArch64LoopIdiomTransform::run we have already checked that the loop
+  // has a preheader so we can assume it's in a canonical form.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
     return false;
 
-  // Check that the index to the GEPs is the index we found earlier
-  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
+  // We only expect one PHI node for the index.
----------------
david-arm wrote:

So I think this is possible, but again, if you don't mind, I've kept the original code. The reason is that `getInductionVariable` requires initialising the ScalarEvolution class and doing work to create SCEVAddRecExpr and InductionDescriptor objects for nodes that we're going to throw away and ignore. It's also far more powerful than we need, since it supports more induction variables than the narrow set we care about, i.e. it supports FP inductions, increments != 1, non-linear inductions, etc. It feels like a shame to do all that work and then afterwards restrict it to the very narrow set we actually care about for these tight loops.

https://github.com/llvm/llvm-project/pull/78471


More information about the llvm-commits mailing list