[llvm] [NFC][AArch64] Refactor AArch64LoopIdiomTransform in preparation for more idioms (PR #78471)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 23 05:55:28 PST 2024


================
@@ -202,32 +227,135 @@ bool AArch64LoopIdiomTransform::run(Loop *L) {
                     << CurLoop->getHeader()->getParent()->getName()
                     << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
 
-  return recognizeByteCompare();
+  MemCompareIdiom BCI(TTI, DT, LI, L);
+  if (BCI.recognize()) {
+    BCI.transform();
+    return true;
+  }
+
+  return false;
 }
 
-bool AArch64LoopIdiomTransform::recognizeByteCompare() {
-  // Currently the transformation only works on scalable vector types, although
-  // there is no fundamental reason why it cannot be made to work for fixed
-  // width too.
+bool MemCompareIdiom::areValidLoads(Value *BaseIdx, Value *LoadA,
+                                    Value *LoadB) {
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+    return false;
 
-  // We also need to know the minimum page size for the target in order to
-  // generate runtime memory checks to ensure the vector version won't fault.
-  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
-      DisableByteCmp)
+  if (!cast<LoadInst>(LoadA)->isSimple() || !cast<LoadInst>(LoadB)->isSimple())
     return false;
 
-  BasicBlock *Header = CurLoop->getHeader();
+  GEPA = dyn_cast<GetElementPtrInst>(A);
+  GEPB = dyn_cast<GetElementPtrInst>(B);
 
-  // In AArch64LoopIdiomTransform::run we have already checked that the loop
-  // has a preheader so we can assume it's in a canonical form.
-  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+  if (!GEPA || !GEPB)
     return false;
 
-  PHINode *PN = dyn_cast<PHINode>(&Header->front());
-  if (!PN || PN->getNumIncomingValues() != 2)
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Check we are loading i8 values from two loop invariant pointers
+  Type *MemType = GEPA->getResultElementType();
+  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
+      !MemType->isIntegerTy(8) || GEPB->getResultElementType() != MemType ||
+      cast<LoadInst>(LoadA)->getType() != MemType ||
+      cast<LoadInst>(LoadB)->getType() != MemType || PtrA == PtrB)
+    return false;
+
+  // Check that the index to the GEPs is the index we found earlier
+  if (GEPA->getNumIndices() != 1 || GEPB->getNumIndices() != 1)
+    return false;
+
+  Value *IdxA = GEPA->getOperand(1);
+  Value *IdxB = GEPB->getOperand(1);
+
+  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(BaseIdx))))
     return false;
 
+  return true;
+}
+
+static bool checkBlockSizes(Loop *CurLoop, unsigned Block1Size,
+                            unsigned Block2Size) {
   auto LoopBlocks = CurLoop->getBlocks();
+
+  auto BB1 = LoopBlocks[0]->instructionsWithoutDebug();
+  if (std::distance(BB1.begin(), BB1.end()) > Block1Size)
+    return false;
+
+  auto BB2 = LoopBlocks[1]->instructionsWithoutDebug();
+  if (std::distance(BB2.begin(), BB2.end()) > Block2Size)
+    return false;
+
+  return true;
+}
+
+bool MemCompareIdiom::checkIterationCondition(BasicBlock *CheckItBB,
+                                              BasicBlock *&LoopBB) {
+  ICmpInst::Predicate Pred;
+  if (!match(CheckItBB->getTerminator(),
----------------
david-arm wrote:

I think I looked at this before, and it requires the loop to have a single latch. In this case the icmp occurs in the first block, not the last block. I'll double-check, but I think that was the reason I couldn't use it.

https://github.com/llvm/llvm-project/pull/78471


More information about the llvm-commits mailing list