[llvm] [NFC][AArch64] Refactor AArch64LoopIdiomTransform in preparation for more idioms (PR #78471)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 23 05:55:29 PST 2024
================
@@ -251,145 +375,90 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
// %cmp.not.ld = icmp eq i8 %load.a, %load.b
// br i1 %cmp.not.ld, label %while.cond, label %while.end
//
- auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
- if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+ if (!checkBlockSizes(CurLoop, 4, 7))
return false;
- // The incoming value to the PHI node from the loop should be an add of 1.
- Value *StartIdx = nullptr;
- Instruction *Index = nullptr;
- if (!CurLoop->contains(PN->getIncomingBlock(0))) {
- StartIdx = PN->getIncomingValue(0);
- Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
- } else {
- StartIdx = PN->getIncomingValue(1);
- Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
- }
-
- // Limit to 32-bit types for now
- if (!Index || !Index->getType()->isIntegerTy(32) ||
- !match(Index, m_c_Add(m_Specific(PN), m_One())))
- return false;
-
- // If we match the pattern, PN and Index will be replaced with the result of
- // the cttz.elts intrinsic. If any other instructions are used outside of
- // the loop, we cannot replace it.
- for (BasicBlock *BB : LoopBlocks)
- for (Instruction &I : *BB)
- if (&I != PN && &I != Index)
- for (User *U : I.users())
- if (!CurLoop->contains(cast<Instruction>(U)))
- return false;
-
// Match the branch instruction for the header
- ICmpInst::Predicate Pred;
- Value *MaxLen;
- BasicBlock *EndBB, *WhileBB;
- if (!match(Header->getTerminator(),
- m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
- m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
- Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
+ BasicBlock *Header = CurLoop->getHeader();
+ BasicBlock *WhileBodyBB;
+ if (!checkIterationCondition(Header, WhileBodyBB))
return false;
// WhileBB should contain the pattern of load & compare instructions. Match
// the pattern and find the GEP instructions used by the loads.
- ICmpInst::Predicate WhilePred;
- BasicBlock *FoundBB;
- BasicBlock *TrueBB;
Value *LoadA, *LoadB;
- if (!match(WhileBB->getTerminator(),
- m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
- m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
- WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
+ if (!checkLoadCompareCondition(WhileBodyBB, LoadA, LoadB))
return false;
- Value *A, *B;
- if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+ if (!areValidLoads(Index, LoadA, LoadB))
return false;
- LoadInst *LoadAI = cast<LoadInst>(LoadA);
- LoadInst *LoadBI = cast<LoadInst>(LoadB);
- if (!LoadAI->isSimple() || !LoadBI->isSimple())
+ if (!checkEndAndFoundBlockPhis(Index))
return false;
- GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
- GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
+ return true;
+}
- if (!GEPA || !GEPB)
+bool MemCompareIdiom::recognize() {
+ // Currently the transformation only works on scalable vector types, although
+ // there is no fundamental reason why it cannot be made to work for fixed
+ // width too.
+
+ // We also need to know the minimum page size for the target in order to
+ // generate runtime memory checks to ensure the vector version won't fault.
+ if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+ DisableByteCmp)
return false;
- Value *PtrA = GEPA->getPointerOperand();
- Value *PtrB = GEPB->getPointerOperand();
+ BasicBlock *Header = CurLoop->getHeader();
- // Check we are loading i8 values from two loop invariant pointers
- if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
- !GEPA->getResultElementType()->isIntegerTy(8) ||
- !GEPB->getResultElementType()->isIntegerTy(8) ||
- !LoadAI->getType()->isIntegerTy(8) ||
- !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
+ // In AArch64LoopIdiomTransform::run we have already checked that the loop
+ // has a preheader so we can assume it's in a canonical form.
+ if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
return false;
- // Check that the index to the GEPs is the index we found earlier
- if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
+ // We only expect one PHI node for the index.
+ IndPhi = dyn_cast<PHINode>(&Header->front());
+ if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
return false;
- Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
- Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
- if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
- return false;
+ // The incoming value to the PHI node from the loop should be an add of 1.
+ StartIdx = nullptr;
+ Index = nullptr;
+ if (!CurLoop->contains(IndPhi->getIncomingBlock(0))) {
+ StartIdx = IndPhi->getIncomingValue(0);
+ Index = dyn_cast<Instruction>(IndPhi->getIncomingValue(1));
+ } else {
+ StartIdx = IndPhi->getIncomingValue(1);
+ Index = dyn_cast<Instruction>(IndPhi->getIncomingValue(0));
+ }
- // We only ever expect the pre-incremented index value to be used inside the
- // loop.
- if (!PN->hasOneUse())
+ if (!Index || !Index->getType()->isIntegerTy(32) ||
+ !match(Index, m_c_Add(m_Specific(IndPhi), m_One())))
return false;
- // Ensure that when the Found and End blocks are identical the PHIs have the
- // supported format. We don't currently allow cases like this:
- // while.cond:
- // ...
- // br i1 %cmp.not, label %while.end, label %while.body
- //
- // while.body:
- // ...
- // br i1 %cmp.not2, label %while.cond, label %while.end
- //
- // while.end:
- // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
- //
- // Where the incoming values for %final_ptr are unique and from each of the
- // loop blocks, but not actually defined in the loop. This requires extra
- // work setting up the byte.compare block, i.e. by introducing a select to
- // choose the correct value.
- // TODO: We could add support for this in future.
- if (FoundBB == EndBB) {
- for (PHINode &EndPN : EndBB->phis()) {
- Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
- Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);
+ // If we match the pattern, IndPhi and Index will be replaced with the result
+ // of the mismatch. If any other instructions are used outside of the loop, we
----------------
david-arm wrote:
I think it refers to initial versions of the patch - I'll fix it!
https://github.com/llvm/llvm-project/pull/78471
More information about the llvm-commits mailing list