[llvm] [NFC][AArch64] Refactor AArch64LoopIdiomTransform in preparation for more idioms (PR #78471)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 23 05:55:28 PST 2024
================
@@ -202,32 +227,135 @@ bool AArch64LoopIdiomTransform::run(Loop *L) {
<< CurLoop->getHeader()->getParent()->getName()
<< "] Loop %" << CurLoop->getHeader()->getName() << "\n");
- return recognizeByteCompare();
+ MemCompareIdiom BCI(TTI, DT, LI, L);
+ if (BCI.recognize()) {
+ BCI.transform();
+ return true;
+ }
+
+ return false;
}
-bool AArch64LoopIdiomTransform::recognizeByteCompare() {
- // Currently the transformation only works on scalable vector types, although
- // there is no fundamental reason why it cannot be made to work for fixed
- // width too.
+bool MemCompareIdiom::areValidLoads(Value *BaseIdx, Value *LoadA,
+ Value *LoadB) {
+ Value *A, *B;
+ if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+ return false;
- // We also need to know the minimum page size for the target in order to
- // generate runtime memory checks to ensure the vector version won't fault.
- if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
- DisableByteCmp)
+ if (!cast<LoadInst>(LoadA)->isSimple() || !cast<LoadInst>(LoadB)->isSimple())
return false;
- BasicBlock *Header = CurLoop->getHeader();
+ GEPA = dyn_cast<GetElementPtrInst>(A);
+ GEPB = dyn_cast<GetElementPtrInst>(B);
- // In AArch64LoopIdiomTransform::run we have already checked that the loop
- // has a preheader so we can assume it's in a canonical form.
- if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+ if (!GEPA || !GEPB)
return false;
- PHINode *PN = dyn_cast<PHINode>(&Header->front());
- if (!PN || PN->getNumIncomingValues() != 2)
+ Value *PtrA = GEPA->getPointerOperand();
+ Value *PtrB = GEPB->getPointerOperand();
+
+ // Check we are loading i8 values from two loop invariant pointers
+ Type *MemType = GEPA->getResultElementType();
+ if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
+ !MemType->isIntegerTy(8) || GEPB->getResultElementType() != MemType ||
+ cast<LoadInst>(LoadA)->getType() != MemType ||
+ cast<LoadInst>(LoadB)->getType() != MemType || PtrA == PtrB)
+ return false;
+
+ // Check that the index to the GEPs is the index we found earlier
+ if (GEPA->getNumIndices() != 1 || GEPB->getNumIndices() != 1)
+ return false;
+
+ Value *IdxA = GEPA->getOperand(1);
+ Value *IdxB = GEPB->getOperand(1);
+
+ if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(BaseIdx))))
return false;
+ return true;
+}
+
+static bool checkBlockSizes(Loop *CurLoop, unsigned Block1Size,
+ unsigned Block2Size) {
auto LoopBlocks = CurLoop->getBlocks();
+
+ auto BB1 = LoopBlocks[0]->instructionsWithoutDebug();
+ if (std::distance(BB1.begin(), BB1.end()) > Block1Size)
+ return false;
+
+ auto BB2 = LoopBlocks[1]->instructionsWithoutDebug();
+ if (std::distance(BB2.begin(), BB2.end()) > Block2Size)
+ return false;
+
+ return true;
+}
+
+bool MemCompareIdiom::checkIterationCondition(BasicBlock *CheckItBB,
+ BasicBlock *&LoopBB) {
+ ICmpInst::Predicate Pred;
+ if (!match(CheckItBB->getTerminator(),
----------------
david-arm wrote:
I think I looked at this before, and the suggested approach requires loops to have a single latch. The icmp in this case occurs in the first block, not the last block. I'll double-check, but I think there was a reason I couldn't use it.
https://github.com/llvm/llvm-project/pull/78471
More information about the llvm-commits mailing list