[llvm] [NFC][AArch64] Refactor AArch64LoopIdiomTransform in preparation for more idioms (PR #78471)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 23 07:54:43 PST 2024


================
@@ -202,32 +227,135 @@ bool AArch64LoopIdiomTransform::run(Loop *L) {
                     << CurLoop->getHeader()->getParent()->getName()
                     << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
 
-  return recognizeByteCompare();
+  MemCompareIdiom BCI(TTI, DT, LI, L);
+  if (BCI.recognize()) {
+    BCI.transform();
+    return true;
+  }
+
+  return false;
 }
 
-bool AArch64LoopIdiomTransform::recognizeByteCompare() {
-  // Currently the transformation only works on scalable vector types, although
-  // there is no fundamental reason why it cannot be made to work for fixed
-  // width too.
+bool MemCompareIdiom::areValidLoads(Value *BaseIdx, Value *LoadA,
+                                    Value *LoadB) {
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+    return false;
 
-  // We also need to know the minimum page size for the target in order to
-  // generate runtime memory checks to ensure the vector version won't fault.
-  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
-      DisableByteCmp)
+  if (!cast<LoadInst>(LoadA)->isSimple() || !cast<LoadInst>(LoadB)->isSimple())
     return false;
 
-  BasicBlock *Header = CurLoop->getHeader();
+  GEPA = dyn_cast<GetElementPtrInst>(A);
+  GEPB = dyn_cast<GetElementPtrInst>(B);
 
-  // In AArch64LoopIdiomTransform::run we have already checked that the loop
-  // has a preheader so we can assume it's in a canonical form.
-  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+  if (!GEPA || !GEPB)
     return false;
 
-  PHINode *PN = dyn_cast<PHINode>(&Header->front());
-  if (!PN || PN->getNumIncomingValues() != 2)
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Check we are loading i8 values from two loop invariant pointers
+  Type *MemType = GEPA->getResultElementType();
+  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
+      !MemType->isIntegerTy(8) || GEPB->getResultElementType() != MemType ||
+      cast<LoadInst>(LoadA)->getType() != MemType ||
+      cast<LoadInst>(LoadB)->getType() != MemType || PtrA == PtrB)
+    return false;
+
+  // Check that the index to the GEPs is the index we found earlier
+  if (GEPA->getNumIndices() != 1 || GEPB->getNumIndices() != 1)
+    return false;
+
+  Value *IdxA = GEPA->getOperand(1);
+  Value *IdxB = GEPB->getOperand(1);
+
+  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(BaseIdx))))
     return false;
 
+  return true;
+}
+
+static bool checkBlockSizes(Loop *CurLoop, unsigned Block1Size,
+                            unsigned Block2Size) {
   auto LoopBlocks = CurLoop->getBlocks();
+
+  auto BB1 = LoopBlocks[0]->instructionsWithoutDebug();
+  if (std::distance(BB1.begin(), BB1.end()) > Block1Size)
+    return false;
+
+  auto BB2 = LoopBlocks[1]->instructionsWithoutDebug();
+  if (std::distance(BB2.begin(), BB2.end()) > Block2Size)
+    return false;
+
+  return true;
+}
+
+bool MemCompareIdiom::checkIterationCondition(BasicBlock *CheckItBB,
+                                              BasicBlock *&LoopBB) {
+  ICmpInst::Predicate Pred;
+  if (!match(CheckItBB->getTerminator(),
----------------
david-arm wrote:

So this function (`getLatchCmpInst`) does return something, but if you don't mind, I've not used it here because I'd still have to do the work to extract the blocks in the branch and check the icmp predicate, Index, MaxLen, etc. I think the code is more compact as it stands, whereas using the latch function would require something like this:

```
  ICmpInst *IC = CurLoop->getLatchCmpInst();
  if (!IC ||
      !match(CheckItBB->getTerminator(), m_Br(m_Specific(IC), m_BasicBlock(EndBB), m_BasicBlock(LoopBB))) ||
      IC->getOperand(0) != Index ||
      IC->getPredicate() != ICmpInst::Predicate::ICMP_EQ ||
      !CurLoop->contains(LoopBB))
     return false;

  MaxLen = IC->getOperand(1);
  return true;
```

https://github.com/llvm/llvm-project/pull/78471


More information about the llvm-commits mailing list