[llvm] [AArch64] Add MATCH loops to LoopIdiomVectorizePass (PR #101976)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 30 08:05:03 PST 2025
================
@@ -939,3 +988,400 @@ void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
report_fatal_error("Loops must remain in LCSSA form!");
}
}
+
+bool LoopIdiomVectorize::recognizeFindFirstByte() {
+ // Currently the transformation only works on scalable vector types, although
+ // there is no fundamental reason why it cannot be made to work for fixed
+ // vectors. We also need to know the target's minimum page size in order to
+ // generate runtime memory checks to ensure the vector version won't fault.
+ if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+ DisableFindFirstByte)
+ return false;
+
+ // Define some constants we need throughout.
+ BasicBlock *Header = CurLoop->getHeader();
+ LLVMContext &Ctx = Header->getContext();
+
+ // We are expecting the four blocks defined below: Header, MatchBB, InnerBB,
+ // and OuterBB. For now, we will bail out for almost anything else. The four
+ // blocks contain one nested loop.
+ if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 ||
+ CurLoop->getSubLoops().size() != 1)
+ return false;
+
+ auto *InnerLoop = CurLoop->getSubLoops().front();
+ PHINode *IndPhi = dyn_cast<PHINode>(&Header->front());
+ if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
+ return false;
+
+ // Check instruction counts.
+ auto LoopBlocks = CurLoop->getBlocks();
+ if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
+ LoopBlocks[1]->sizeWithoutDebug() > 4 ||
+ LoopBlocks[2]->sizeWithoutDebug() > 3 ||
+ LoopBlocks[3]->sizeWithoutDebug() > 3)
+ return false;
+
+ // Check that no instruction other than IndPhi has outside uses.
+ for (BasicBlock *BB : LoopBlocks)
+ for (Instruction &I : *BB)
+ if (&I != IndPhi)
+ for (User *U : I.users())
+ if (!CurLoop->contains(cast<Instruction>(U)))
+ return false;
+
+ // Match the branch instruction in the header. We are expecting an
+ // unconditional branch to the inner loop.
+ //
+ // Header:
+ // %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ]
+ // %15 = load i8, ptr %14, align 1
+ // br label %MatchBB
+ BasicBlock *MatchBB;
+ if (!match(Header->getTerminator(), m_UnconditionalBr(MatchBB)) ||
+ !InnerLoop->contains(MatchBB))
+ return false;
+
+ // MatchBB should be the entrypoint into the inner loop containing the
+ // comparison between a search element and a needle.
+ //
+ // MatchBB:
+ // %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ]
+ // %21 = load i8, ptr %20, align 1
+ // %22 = icmp eq i8 %15, %21
+ // br i1 %22, label %ExitSucc, label %InnerBB
+ BasicBlock *ExitSucc, *InnerBB;
+ Value *LoadSearch, *LoadNeedle;
+ CmpPredicate MatchPred;
+ if (!match(MatchBB->getTerminator(),
+ m_Br(m_ICmp(MatchPred, m_Value(LoadSearch), m_Value(LoadNeedle)),
+ m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
+ MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains(InnerBB))
+ return false;
+
+ // We expect outside uses of `IndPhi' in ExitSucc (and only there).
+ for (User *U : IndPhi->users())
+ if (!CurLoop->contains(cast<Instruction>(U))) {
+ auto *PN = dyn_cast<PHINode>(U);
+ if (!PN || PN->getParent() != ExitSucc)
+ return false;
+ }
+
+ // Match the loads and check they are simple.
+ Value *Search, *Needle;
+ if (!match(LoadSearch, m_Load(m_Value(Search))) ||
+ !match(LoadNeedle, m_Load(m_Value(Needle))) ||
+ !cast<LoadInst>(LoadSearch)->isSimple() ||
+ !cast<LoadInst>(LoadNeedle)->isSimple())
+ return false;
+
+ // Check we are loading valid characters.
+ Type *CharTy = LoadSearch->getType();
+ if (!CharTy->isIntegerTy() || LoadNeedle->getType() != CharTy)
+ return false;
+
+ // Pick the vectorisation factor based on CharTy, work out the cost of the
+ // match intrinsic and decide if we should use it.
+ // Note: For the time being we assume 128-bit vectors.
+ unsigned VF = 128 / CharTy->getIntegerBitWidth();
+ SmallVector<Type *> Args = {
+ ScalableVectorType::get(CharTy, VF), FixedVectorType::get(CharTy, VF),
+ ScalableVectorType::get(Type::getInt1Ty(Ctx), VF)};
+ IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2],
+ Args);
+ if (TTI->getIntrinsicInstrCost(Attrs, TTI::TCK_SizeAndLatency) > 4)
+ return false;
+
+ // The loads come from two PHIs, each with two incoming values.
+ PHINode *PSearch = dyn_cast<PHINode>(Search);
+ PHINode *PNeedle = dyn_cast<PHINode>(Needle);
+ if (!PSearch || PSearch->getNumIncomingValues() != 2 || !PNeedle ||
+ PNeedle->getNumIncomingValues() != 2)
+ return false;
+
+ // One PHI comes from the outer loop (PSearch), the other one from the inner
+ // loop (PNeedle). PSearch effectively corresponds to IndPhi.
+ if (InnerLoop->contains(PSearch))
+ std::swap(PSearch, PNeedle);
+ if (PSearch != &Header->front() || PNeedle != &MatchBB->front())
+ return false;
+
+ // Each PHI node's in-loop incoming value should be a gep of 1.
+ Value *SearchStart = PSearch->getIncomingValue(0);
+ Value *SearchIndex = PSearch->getIncomingValue(1);
+ if (CurLoop->contains(PSearch->getIncomingBlock(0)))
+ std::swap(SearchStart, SearchIndex);
+
+ Value *NeedleStart = PNeedle->getIncomingValue(0);
+ Value *NeedleIndex = PNeedle->getIncomingValue(1);
+ if (InnerLoop->contains(PNeedle->getIncomingBlock(0)))
+ std::swap(NeedleStart, NeedleIndex);
+
+ // Match the GEPs.
+ if (!match(SearchIndex, m_GEP(m_Specific(PSearch), m_One())) ||
+ !match(NeedleIndex, m_GEP(m_Specific(PNeedle), m_One())))
+ return false;
+
+ // Check the GEPs result type matches `CharTy'.
+ GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(SearchIndex);
+ GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(NeedleIndex);
+ if (GEPSearch->getResultElementType() != CharTy ||
+ GEPNeedle->getResultElementType() != CharTy)
+ return false;
+
+ // InnerBB should increment the address of the needle pointer.
+ //
+ // InnerBB:
+ // %17 = getelementptr inbounds i8, ptr %20, i64 1
+ // %18 = icmp eq ptr %17, %10
+ // br i1 %18, label %OuterBB, label %MatchBB
+ BasicBlock *OuterBB;
+ Value *NeedleEnd;
+ if (!match(InnerBB->getTerminator(),
+ m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPNeedle),
+ m_Value(NeedleEnd)),
+ m_BasicBlock(OuterBB), m_Specific(MatchBB))) ||
+ !CurLoop->contains(OuterBB))
+ return false;
+
+ // OuterBB should increment the address of the search element pointer.
+ //
+ // OuterBB:
+ // %24 = getelementptr inbounds i8, ptr %14, i64 1
+ // %25 = icmp eq ptr %24, %6
+ // br i1 %25, label %ExitFail, label %Header
+ BasicBlock *ExitFail;
+ Value *SearchEnd;
+ if (!match(OuterBB->getTerminator(),
+ m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPSearch),
+ m_Value(SearchEnd)),
+ m_BasicBlock(ExitFail), m_Specific(Header))))
+ return false;
+
+ LLVM_DEBUG(dbgs() << "Found idiom in loop: \n" << *CurLoop << "\n\n");
+
+ transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart,
+ SearchEnd, NeedleStart, NeedleEnd);
+ return true;
+}
+
+Value *LoopIdiomVectorize::expandFindFirstByte(
+ IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy,
+ BasicBlock *ExitSucc, BasicBlock *ExitFail, Value *SearchStart,
+ Value *SearchEnd, Value *NeedleStart, Value *NeedleEnd) {
+ // Set up some types and constants that we intend to reuse.
+ auto *PtrTy = Builder.getPtrTy();
+ auto *I64Ty = Builder.getInt64Ty();
+ auto *PredVTy = ScalableVectorType::get(Builder.getInt1Ty(), VF);
+ auto *CharVTy = ScalableVectorType::get(CharTy, VF);
+ auto *ConstVF = ConstantInt::get(I64Ty, VF);
+
+ // Other common arguments.
+ BasicBlock *Preheader = CurLoop->getLoopPreheader();
+ LLVMContext &Ctx = Preheader->getContext();
+ Value *Passthru = ConstantInt::getNullValue(CharVTy);
+
+ // Split block in the original loop preheader.
+ // SPH is the new preheader to the old scalar loop.
+ BasicBlock *SPH = SplitBlock(Preheader, Preheader->getTerminator(), DT, LI,
+ nullptr, "scalar_ph");
+
+ // Create the blocks that we're going to use.
+ //
+ // We will have the following loops:
+ // (O) Outer loop where we iterate over the elements of the search array.
+ // (I) Inner loop where we iterate over the elements of the needle array.
+ //
+ // Overall, the blocks do the following:
+ // (0) Check if the arrays can't cross page boundaries. If so go to (1),
+ // otherwise fall back to the original scalar loop.
+ // (1) Load the search array. Go to (2).
+ // (2) (a) Load the needle array.
+ // (b) Splat the first element to the inactive lanes.
+ // (c) Check if any elements match. If so go to (3), otherwise go to (4).
+ // (3) Compute the index of the first match and exit.
+ // (4) Check if we've reached the end of the needle array. If not loop back to
+ // (2), otherwise go to (5).
+ // (5) Check if we've reached the end of the search array. If not loop back to
+ // (1), otherwise exit.
+ // Blocks (0,3) are not part of any loop. Blocks (1,5) and (2,4) belong to
+ // the outer and inner loops, respectively.
+ BasicBlock *BB0 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+ BasicBlock *BB1 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+ BasicBlock *BB2 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+ BasicBlock *BB3 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+ BasicBlock *BB4 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+ BasicBlock *BB5 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+
+ // Update LoopInfo with the new loops.
+ auto OuterLoop = LI->AllocateLoop();
+ auto InnerLoop = LI->AllocateLoop();
+
+ if (auto ParentLoop = CurLoop->getParentLoop()) {
+ ParentLoop->addBasicBlockToLoop(BB0, *LI);
+ ParentLoop->addChildLoop(OuterLoop);
+ ParentLoop->addBasicBlockToLoop(BB3, *LI);
+ } else {
+ LI->addTopLevelLoop(OuterLoop);
+ }
+
+ // Add the inner loop to the outer.
+ OuterLoop->addChildLoop(InnerLoop);
+
+ // Add the new basic blocks to the corresponding loops.
+ OuterLoop->addBasicBlockToLoop(BB1, *LI);
+ OuterLoop->addBasicBlockToLoop(BB5, *LI);
+ InnerLoop->addBasicBlockToLoop(BB2, *LI);
+ InnerLoop->addBasicBlockToLoop(BB4, *LI);
+
+ // Update the terminator added by SplitBlock to branch to the first block.
+ Preheader->getTerminator()->setSuccessor(0, BB0);
+ DTU.applyUpdates({{DominatorTree::Delete, Preheader, SPH},
+ {DominatorTree::Insert, Preheader, BB0}});
+
+ // (0) Check if we could be crossing a page boundary; if so, fall back to
+ // the old scalar loops. Also create a predicate of VF elements to be used in
+ // the vector loops.
+ Builder.SetInsertPoint(BB0);
+ Value *ISearchStart = Builder.CreatePtrToInt(SearchStart, I64Ty);
+ Value *ISearchEnd = Builder.CreatePtrToInt(SearchEnd, I64Ty);
+ Value *INeedleStart = Builder.CreatePtrToInt(NeedleStart, I64Ty);
+ Value *INeedleEnd = Builder.CreatePtrToInt(NeedleEnd, I64Ty);
----------------
david-arm wrote:
Maybe you get this check for free due to how you've checked all the IR in the outer loop.
https://github.com/llvm/llvm-project/pull/101976
More information about the llvm-commits
mailing list