[llvm] [NFC][AArch64] Refactor AArch64LoopIdiomTransform in preparation for more idioms (PR #78471)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 23 05:55:28 PST 2024
================
@@ -251,145 +375,90 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
// %cmp.not.ld = icmp eq i8 %load.a, %load.b
// br i1 %cmp.not.ld, label %while.cond, label %while.end
//
- auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
- if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+ if (!checkBlockSizes(CurLoop, 4, 7))
return false;
- // The incoming value to the PHI node from the loop should be an add of 1.
- Value *StartIdx = nullptr;
- Instruction *Index = nullptr;
- if (!CurLoop->contains(PN->getIncomingBlock(0))) {
- StartIdx = PN->getIncomingValue(0);
- Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
- } else {
- StartIdx = PN->getIncomingValue(1);
- Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
- }
-
- // Limit to 32-bit types for now
- if (!Index || !Index->getType()->isIntegerTy(32) ||
- !match(Index, m_c_Add(m_Specific(PN), m_One())))
- return false;
-
- // If we match the pattern, PN and Index will be replaced with the result of
- // the cttz.elts intrinsic. If any other instructions are used outside of
- // the loop, we cannot replace it.
- for (BasicBlock *BB : LoopBlocks)
- for (Instruction &I : *BB)
- if (&I != PN && &I != Index)
- for (User *U : I.users())
- if (!CurLoop->contains(cast<Instruction>(U)))
- return false;
-
// Match the branch instruction for the header
- ICmpInst::Predicate Pred;
- Value *MaxLen;
- BasicBlock *EndBB, *WhileBB;
- if (!match(Header->getTerminator(),
- m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
- m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
- Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
+ BasicBlock *Header = CurLoop->getHeader();
+ BasicBlock *WhileBodyBB;
+ if (!checkIterationCondition(Header, WhileBodyBB))
return false;
// WhileBB should contain the pattern of load & compare instructions. Match
// the pattern and find the GEP instructions used by the loads.
- ICmpInst::Predicate WhilePred;
- BasicBlock *FoundBB;
- BasicBlock *TrueBB;
Value *LoadA, *LoadB;
- if (!match(WhileBB->getTerminator(),
- m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
- m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
- WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
+ if (!checkLoadCompareCondition(WhileBodyBB, LoadA, LoadB))
return false;
- Value *A, *B;
- if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+ if (!areValidLoads(Index, LoadA, LoadB))
return false;
- LoadInst *LoadAI = cast<LoadInst>(LoadA);
- LoadInst *LoadBI = cast<LoadInst>(LoadB);
- if (!LoadAI->isSimple() || !LoadBI->isSimple())
+ if (!checkEndAndFoundBlockPhis(Index))
return false;
- GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
- GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
+ return true;
+}
- if (!GEPA || !GEPB)
+bool MemCompareIdiom::recognize() {
+ // Currently the transformation only works on scalable vector types, although
+ // there is no fundamental reason why it cannot be made to work for fixed
+ // width too.
+
+ // We also need to know the minimum page size for the target in order to
+ // generate runtime memory checks to ensure the vector version won't fault.
+ if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+ DisableByteCmp)
return false;
- Value *PtrA = GEPA->getPointerOperand();
- Value *PtrB = GEPB->getPointerOperand();
+ BasicBlock *Header = CurLoop->getHeader();
- // Check we are loading i8 values from two loop invariant pointers
- if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
- !GEPA->getResultElementType()->isIntegerTy(8) ||
- !GEPB->getResultElementType()->isIntegerTy(8) ||
- !LoadAI->getType()->isIntegerTy(8) ||
- !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
+ // In AArch64LoopIdiomTransform::run we have already checked that the loop
+ // has a preheader so we can assume it's in a canonical form.
+ if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
return false;
- // Check that the index to the GEPs is the index we found earlier
- if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
+ // We only expect one PHI node for the index.
----------------
david-arm wrote:
Possibly, although `getInductionVariable` calls `getLatchCmpInst`, so I think we might run into the same issue mentioned above.
https://github.com/llvm/llvm-project/pull/78471
More information about the llvm-commits mailing list