[llvm] [AArch64] Add MATCH loops to LoopIdiomVectorizePass (PR #101976)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 6 06:57:15 PST 2025


================
@@ -939,3 +988,432 @@ void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
       report_fatal_error("Loops must remain in LCSSA form!");
   }
 }
+
+/// Recognise the "find first byte" idiom in CurLoop.
+///
+/// The expected shape is an outer loop that walks a search range
+/// [SearchStart, SearchEnd) one element at a time and, for every search
+/// element, an inner loop that walks a needle range [NeedleStart, NeedleEnd)
+/// comparing the two elements for equality. The first equal pair leaves the
+/// loop via ExitSucc; exhausting the search range leaves it via ExitFail.
+/// When the pattern and the target/cost checks all succeed, the loop is
+/// rewritten by transformFindFirstByte and true is returned; otherwise the IR
+/// is left untouched and false is returned.
+bool LoopIdiomVectorize::recognizeFindFirstByte() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // vectors. We also need to know the target's minimum page size in order to
+  // generate runtime memory checks to ensure the vector version won't fault.
+  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+      DisableFindFirstByte)
+    return false;
+
+  // Define some constants we need throughout.
+  BasicBlock *Header = CurLoop->getHeader();
+  LLVMContext &Ctx = Header->getContext();
+
+  // We are expecting the four blocks defined below: Header, MatchBB, InnerBB,
+  // and OuterBB. For now, we will bail out for almost anything else. The four
+  // blocks contain one nested loop.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 ||
+      CurLoop->getSubLoops().size() != 1)
+    return false;
+
+  auto *InnerLoop = CurLoop->getSubLoops().front();
+  PHINode *IndPhi = dyn_cast<PHINode>(&Header->front());
+  if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
+    return false;
+
+  // Check instruction counts.
+  // NOTE(review): this indexes getBlocks() positionally, i.e. it presumes the
+  // order Header, MatchBB, InnerBB, OuterBB; the structural matching below is
+  // what actually validates the shape — TODO confirm the ordering assumption.
+  auto LoopBlocks = CurLoop->getBlocks();
+  if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[1]->sizeWithoutDebug() > 4 ||
+      LoopBlocks[2]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[3]->sizeWithoutDebug() > 3)
+    return false;
+
+  // Check that no instruction other than IndPhi has outside uses.
+  for (BasicBlock *BB : LoopBlocks)
+    for (Instruction &I : *BB)
+      if (&I != IndPhi)
+        for (User *U : I.users())
+          if (!CurLoop->contains(cast<Instruction>(U)))
+            return false;
+
+  // Match the branch instruction in the header. We are expecting an
+  // unconditional branch to the inner loop.
+  //
+  // Header:
+  //   %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ]
+  //   %15 = load i8, ptr %14, align 1
+  //   br label %MatchBB
+  BasicBlock *MatchBB;
+  if (!match(Header->getTerminator(), m_UnconditionalBr(MatchBB)) ||
+      !InnerLoop->contains(MatchBB))
+    return false;
+
+  // MatchBB should be the entrypoint into the inner loop containing the
+  // comparison between a search element and a needle.
+  //
+  // MatchBB:
+  //   %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ]
+  //   %21 = load i8, ptr %20, align 1
+  //   %22 = icmp eq i8 %15, %21
+  //   br i1 %22, label %ExitSucc, label %InnerBB
+  BasicBlock *ExitSucc, *InnerBB;
+  Value *LoadSearch, *LoadNeedle;
+  CmpPredicate MatchPred;
+  if (!match(MatchBB->getTerminator(),
+             m_Br(m_ICmp(MatchPred, m_Value(LoadSearch), m_Value(LoadNeedle)),
+                  m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
+      MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains(InnerBB))
+    return false;
+
+  // The success edge must really leave the loop: the generated vector code
+  // branches straight to ExitSucc, so it must not be a block of the scalar
+  // loop (that would enter the loop body without passing through Header).
+  if (CurLoop->contains(ExitSucc))
+    return false;
+
+  // We expect outside uses of `IndPhi' in ExitSucc (and only there).
+  for (User *U : IndPhi->users())
+    if (!CurLoop->contains(cast<Instruction>(U))) {
+      auto *PN = dyn_cast<PHINode>(U);
+      if (!PN || PN->getParent() != ExitSucc)
+        return false;
+    }
+
+  // Match the loads and check they are simple.
+  Value *Search, *Needle;
+  if (!match(LoadSearch, m_Load(m_Value(Search))) ||
+      !match(LoadNeedle, m_Load(m_Value(Needle))) ||
+      !cast<LoadInst>(LoadSearch)->isSimple() ||
+      !cast<LoadInst>(LoadNeedle)->isSimple())
+    return false;
+
+  // Check we are loading valid characters.
+  Type *CharTy = LoadSearch->getType();
+  if (!CharTy->isIntegerTy() || LoadNeedle->getType() != CharTy)
+    return false;
+
+  // Only whole-byte, power-of-two character widths up to 128 bits are
+  // supported. Wider types would make VF below zero (an invalid vector
+  // type), and sub-byte or odd widths do not lay vector lanes out with the
+  // same stride as the scalar one-element GEPs.
+  const unsigned CharBits = CharTy->getIntegerBitWidth();
+  if (CharBits < 8 || CharBits > 128 || !isPowerOf2_32(CharBits))
+    return false;
+
+  // Pick the vectorisation factor based on CharTy, work out the cost of the
+  // match intrinsic and decide if we should use it.
+  // Note: For the time being we assume 128-bit vectors.
+  unsigned VF = 128 / CharBits;
+  SmallVector<Type *> Args = {
+      ScalableVectorType::get(CharTy, VF), FixedVectorType::get(CharTy, VF),
+      ScalableVectorType::get(Type::getInt1Ty(Ctx), VF)};
+  IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2],
+                                Args);
+  if (TTI->getIntrinsicInstrCost(Attrs, TTI::TCK_SizeAndLatency) > 4)
+    return false;
+
+  // The loads come from two PHIs, each with two incoming values.
+  PHINode *PSearch = dyn_cast<PHINode>(Search);
+  PHINode *PNeedle = dyn_cast<PHINode>(Needle);
+  if (!PSearch || PSearch->getNumIncomingValues() != 2 || !PNeedle ||
+      PNeedle->getNumIncomingValues() != 2)
+    return false;
+
+  // One PHI comes from the outer loop (PSearch), the other one from the inner
+  // loop (PNeedle). PSearch effectively corresponds to IndPhi.
+  if (InnerLoop->contains(PSearch))
+    std::swap(PSearch, PNeedle);
+  if (PSearch != &Header->front() || PNeedle != &MatchBB->front())
+    return false;
+
+  // The incoming values of both PHI nodes should be a gep of 1.
+  Value *SearchStart = PSearch->getIncomingValue(0);
+  Value *SearchIndex = PSearch->getIncomingValue(1);
+  if (CurLoop->contains(PSearch->getIncomingBlock(0)))
+    std::swap(SearchStart, SearchIndex);
+
+  Value *NeedleStart = PNeedle->getIncomingValue(0);
+  Value *NeedleIndex = PNeedle->getIncomingValue(1);
+  if (InnerLoop->contains(PNeedle->getIncomingBlock(0)))
+    std::swap(NeedleStart, NeedleIndex);
+
+  // Match the GEPs.
+  if (!match(SearchIndex, m_GEP(m_Specific(PSearch), m_One())) ||
+      !match(NeedleIndex, m_GEP(m_Specific(PNeedle), m_One())))
+    return false;
+
+  // Check the GEPs result type matches `CharTy'.
+  GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(SearchIndex);
+  GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(NeedleIndex);
+  if (GEPSearch->getResultElementType() != CharTy ||
+      GEPNeedle->getResultElementType() != CharTy)
+    return false;
+
+  // InnerBB should increment the address of the needle pointer.
+  //
+  // InnerBB:
+  //   %17 = getelementptr inbounds i8, ptr %20, i64 1
+  //   %18 = icmp eq ptr %17, %10
+  //   br i1 %18, label %OuterBB, label %MatchBB
+  BasicBlock *OuterBB;
+  Value *NeedleEnd;
+  if (!match(InnerBB->getTerminator(),
+             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPNeedle),
+                                 m_Value(NeedleEnd)),
+                  m_BasicBlock(OuterBB), m_Specific(MatchBB))) ||
+      !CurLoop->contains(OuterBB))
+    return false;
+
+  // OuterBB should increment the address of the search element pointer.
+  //
+  // OuterBB:
+  //   %24 = getelementptr inbounds i8, ptr %14, i64 1
+  //   %25 = icmp eq ptr %24, %6
+  //   br i1 %25, label %ExitFail, label %Header
+  BasicBlock *ExitFail;
+  Value *SearchEnd;
+  if (!match(OuterBB->getTerminator(),
+             m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPSearch),
+                                 m_Value(SearchEnd)),
+                  m_BasicBlock(ExitFail), m_Specific(Header))))
+    return false;
+
+  // As for ExitSucc, the failure edge is targeted directly by the vector
+  // code, so it must lead out of the loop.
+  if (CurLoop->contains(ExitFail))
+    return false;
+
+  if (!CurLoop->isLoopInvariant(SearchStart) ||
+      !CurLoop->isLoopInvariant(SearchEnd) ||
+      !CurLoop->isLoopInvariant(NeedleStart) ||
+      !CurLoop->isLoopInvariant(NeedleEnd))
+    return false;
+
+  LLVM_DEBUG(dbgs() << "Found idiom in loop: \n" << *CurLoop << "\n\n");
+
+  transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart,
+                         SearchEnd, NeedleStart, NeedleEnd);
+  return true;
+}
+
+/// Emit the vectorised find-first-byte loops in front of the original scalar
+/// loop and return the value ExitSucc receives from the vector path: a
+/// pointer to the first search element that matched some needle element.
+///
+/// The original preheader is split so that the new code runs first. It falls
+/// back to the scalar loop ("scalar_preheader") when the start and end
+/// addresses of either range lie on different pages of TTI's minimum page
+/// size. The vector outer loop handles VF search elements per iteration and
+/// the vector inner loop VF needle elements per iteration, comparing them with
+/// llvm.experimental.vector.match. A match branches to ExitSucc; exhausting
+/// the search range branches to ExitFail. SplitBlock updates DT/LI directly,
+/// the new loops are registered in LI here, and the remaining dominator-tree
+/// edits are queued on DTU.
+Value *LoopIdiomVectorize::expandFindFirstByte(
+    IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy,
+    BasicBlock *ExitSucc, BasicBlock *ExitFail, Value *SearchStart,
+    Value *SearchEnd, Value *NeedleStart, Value *NeedleEnd) {
+  // Set up some types and constants that we intend to reuse.
+  auto *PtrTy = Builder.getPtrTy();
+  auto *I64Ty = Builder.getInt64Ty();
+  auto *PredVTy = ScalableVectorType::get(Builder.getInt1Ty(), VF);
+  auto *CharVTy = ScalableVectorType::get(CharTy, VF);
+  auto *ConstVF = ConstantInt::get(I64Ty, VF);
+
+  // Other common arguments.
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  LLVMContext &Ctx = Preheader->getContext();
+  Value *Passthru = ConstantInt::getNullValue(CharVTy);
+
+  // llvm.get.active.lane.mask counts lanes, and every lane holds one CharTy
+  // element, so the addresses fed to it must be expressed in elements rather
+  // than bytes. Otherwise, for CharTy wider than one byte, lanes past the end
+  // of a range would be enabled, reading out of bounds and possibly reporting
+  // a spurious match. The scalar loops step one element at a time and stop on
+  // pointer equality with the end pointer, so the two ends of each range (and
+  // every intermediate vector pointer) differ by a whole number of elements;
+  // dividing both addresses by the element size keeps that distance exact.
+  const DataLayout &DL = Preheader->getModule()->getDataLayout();
+  const uint64_t CharSize = DL.getTypeAllocSize(CharTy).getFixedValue();
+  auto ToElementIdx = [&](Value *Addr, const Twine &Name) -> Value * {
+    if (CharSize == 1)
+      return Addr; // Byte address == element index; i8 IR stays unchanged.
+    if (isPowerOf2_64(CharSize))
+      return Builder.CreateLShr(Addr, Log2_64(CharSize), Name);
+    return Builder.CreateUDiv(Addr, ConstantInt::get(I64Ty, CharSize), Name);
+  };
+
+  // Split block in the original loop preheader.
+  // SPH is the new preheader to the old scalar loop.
+  BasicBlock *SPH = SplitBlock(Preheader, Preheader->getTerminator(), DT, LI,
+                               nullptr, "scalar_preheader");
+
+  // Create the blocks that we're going to use.
+  //
+  // We will have the following loops:
+  // (O) Outer loop where we iterate over the elements of the search array.
+  // (I) Inner loop where we iterate over the elements of the needle array.
+  //
+  // Overall, the blocks do the following:
+  // (0) Check if the arrays can't cross page boundaries. If so go to (1),
+  //     otherwise fall back to the original scalar loop.
+  // (1) Load the search array. Go to (2).
+  // (2) (a) Load the needle array.
+  //     (b) Splat the first element to the inactive lanes.
+  //     (c) Check if any elements match. If so go to (3), otherwise go to (4).
+  // (3) Compute the index of the first match and exit.
+  // (4) Check if we've reached the end of the needle array. If not loop back to
+  //     (2), otherwise go to (5).
+  // (5) Check if we've reached the end of the search array. If not loop back to
+  //     (1), otherwise exit.
+  // Blocks 0 and 3 are not part of any loop. Blocks 1 and 5 belong to the
+  // outer loop, and blocks 2 and 4 to the inner loop.
+  BasicBlock *BB0 = BasicBlock::Create(Ctx, "mem_check", SPH->getParent(), SPH);
+  BasicBlock *BB1 =
+      BasicBlock::Create(Ctx, "find_first_vec_header", SPH->getParent(), SPH);
+  BasicBlock *BB2 =
+      BasicBlock::Create(Ctx, "match_check_vec", SPH->getParent(), SPH);
+  BasicBlock *BB3 =
+      BasicBlock::Create(Ctx, "calculate_match", SPH->getParent(), SPH);
+  BasicBlock *BB4 =
+      BasicBlock::Create(Ctx, "needle_check_vec", SPH->getParent(), SPH);
+  BasicBlock *BB5 =
+      BasicBlock::Create(Ctx, "search_check_vec", SPH->getParent(), SPH);
+
+  // Update LoopInfo with the new loops.
+  auto OuterLoop = LI->AllocateLoop();
+  auto InnerLoop = LI->AllocateLoop();
+
+  if (auto ParentLoop = CurLoop->getParentLoop()) {
+    ParentLoop->addBasicBlockToLoop(BB0, *LI);
+    ParentLoop->addChildLoop(OuterLoop);
+    ParentLoop->addBasicBlockToLoop(BB3, *LI);
+  } else {
+    LI->addTopLevelLoop(OuterLoop);
+  }
+
+  // Add the inner loop to the outer.
+  OuterLoop->addChildLoop(InnerLoop);
+
+  // Add the new basic blocks to the corresponding loops.
+  OuterLoop->addBasicBlockToLoop(BB1, *LI);
+  OuterLoop->addBasicBlockToLoop(BB5, *LI);
+  InnerLoop->addBasicBlockToLoop(BB2, *LI);
+  InnerLoop->addBasicBlockToLoop(BB4, *LI);
+
+  // Update the terminator added by SplitBlock to branch to the first block.
+  Preheader->getTerminator()->setSuccessor(0, BB0);
+  DTU.applyUpdates({{DominatorTree::Delete, Preheader, SPH},
+                    {DominatorTree::Insert, Preheader, BB0}});
+
+  // (0) Check if we could be crossing a page boundary; if so, fallback to the
+  // old scalar loops. Also create a predicate of VF elements to be used in the
+  // vector loops.
+  Builder.SetInsertPoint(BB0);
+  Value *ISearchStart =
+      Builder.CreatePtrToInt(SearchStart, I64Ty, "search_start_int");
+  Value *ISearchEnd =
+      Builder.CreatePtrToInt(SearchEnd, I64Ty, "search_end_int");
+  Value *INeedleStart =
+      Builder.CreatePtrToInt(NeedleStart, I64Ty, "needle_start_int");
+  Value *INeedleEnd =
+      Builder.CreatePtrToInt(NeedleEnd, I64Ty, "needle_end_int");
+  // Loop-invariant end positions in element units for the lane masks.
+  Value *SearchEndIdx = ToElementIdx(ISearchEnd, "search_end_idx");
+  Value *NeedleEndIdx = ToElementIdx(INeedleEnd, "needle_end_idx");
+  Value *PredVF =
+      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
+                              {ConstantInt::get(I64Ty, 0), ConstVF});
+
+  const uint64_t MinPageSize = TTI->getMinPageSize().value();
+  const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
+  Value *SearchStartPage =
+      Builder.CreateLShr(ISearchStart, AddrShiftAmt, "search_start_page");
+  Value *SearchEndPage =
+      Builder.CreateLShr(ISearchEnd, AddrShiftAmt, "search_end_page");
+  Value *NeedleStartPage =
+      Builder.CreateLShr(INeedleStart, AddrShiftAmt, "needle_start_page");
+  Value *NeedleEndPage =
+      Builder.CreateLShr(INeedleEnd, AddrShiftAmt, "needle_end_page");
+  Value *SearchPageCmp =
+      Builder.CreateICmpNE(SearchStartPage, SearchEndPage, "search_page_cmp");
+  Value *NeedlePageCmp =
+      Builder.CreateICmpNE(NeedleStartPage, NeedleEndPage, "needle_page_cmp");
+
+  Value *CombinedPageCmp =
+      Builder.CreateOr(SearchPageCmp, NeedlePageCmp, "combined_page_cmp");
+  BranchInst *CombinedPageBr = Builder.CreateCondBr(CombinedPageCmp, SPH, BB1);
+  CombinedPageBr->setMetadata(LLVMContext::MD_prof,
+                              MDBuilder(Ctx).createBranchWeights(10, 90));
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, BB0, SPH}, {DominatorTree::Insert, BB0, BB1}});
+
+  // (1) Load the search array and branch to the inner loop.
+  Builder.SetInsertPoint(BB1);
+  PHINode *Search = Builder.CreatePHI(PtrTy, 2, "psearch");
+  Value *SearchIdx =
+      ToElementIdx(Builder.CreatePtrToInt(Search, I64Ty), "search_idx");
+  Value *PredSearch = Builder.CreateIntrinsic(
+      Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
+      {SearchIdx, SearchEndIdx}, nullptr, "search_pred");
+  PredSearch = Builder.CreateAnd(PredVF, PredSearch, "search_masked");
+  Value *LoadSearch = Builder.CreateMaskedLoad(
+      CharVTy, Search, Align(1), PredSearch, Passthru, "search_load_vec");
+  Builder.CreateBr(BB2);
+  DTU.applyUpdates({{DominatorTree::Insert, BB1, BB2}});
+
+  // (2) Inner loop.
+  Builder.SetInsertPoint(BB2);
+  PHINode *Needle = Builder.CreatePHI(PtrTy, 2, "pneedle");
+
+  // (2.a) Load the needle array.
+  Value *NeedleIdx =
+      ToElementIdx(Builder.CreatePtrToInt(Needle, I64Ty), "needle_idx");
+  Value *PredNeedle = Builder.CreateIntrinsic(
+      Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
+      {NeedleIdx, NeedleEndIdx}, nullptr, "needle_pred");
+  PredNeedle = Builder.CreateAnd(PredVF, PredNeedle, "needle_masked");
+  Value *LoadNeedle = Builder.CreateMaskedLoad(
+      CharVTy, Needle, Align(1), PredNeedle, Passthru, "needle_load_vec");
+
+  // (2.b) Splat the first element to the inactive lanes.
+  Value *Needle0 =
+      Builder.CreateExtractElement(LoadNeedle, uint64_t(0), "needle0");
+  Value *Needle0Splat = Builder.CreateVectorSplat(ElementCount::getScalable(VF),
+                                                  Needle0, "needle0");
+  LoadNeedle = Builder.CreateSelect(PredNeedle, LoadNeedle, Needle0Splat,
+                                    "needle_splat");
+  LoadNeedle =
+      Builder.CreateExtractVector(FixedVectorType::get(CharTy, VF), LoadNeedle,
+                                  ConstantInt::get(I64Ty, 0), "needle_vec");
+
+  // (2.c) Test if there's a match.
+  Value *MatchPred = Builder.CreateIntrinsic(
+      Intrinsic::experimental_vector_match, {CharVTy, LoadNeedle->getType()},
+      {LoadSearch, LoadNeedle, PredSearch}, nullptr, "match_pred");
+  Value *IfAnyMatch = Builder.CreateOrReduce(MatchPred);
+  Builder.CreateCondBr(IfAnyMatch, BB3, BB4);
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, BB2, BB3}, {DominatorTree::Insert, BB2, BB4}});
+
+  // (3) We found a match. Compute the index of its location and exit.
+  Builder.SetInsertPoint(BB3);
+  PHINode *MatchLCSSA = Builder.CreatePHI(PtrTy, 1, "match_start");
+  PHINode *MatchPredLCSSA =
+      Builder.CreatePHI(MatchPred->getType(), 1, "match_vec");
+  Value *MatchCnt = Builder.CreateIntrinsic(
+      Intrinsic::experimental_cttz_elts, {I64Ty, MatchPred->getType()},
+      {MatchPredLCSSA, /*ZeroIsPoison=*/Builder.getInt1(true)}, nullptr,
+      "match_idx");
+  Value *MatchVal =
+      Builder.CreateGEP(CharTy, MatchLCSSA, MatchCnt, "match_res");
+  Builder.CreateBr(ExitSucc);
+  DTU.applyUpdates({{DominatorTree::Insert, BB3, ExitSucc}});
+
+  // (4) Check if we've reached the end of the needle array.
+  Builder.SetInsertPoint(BB4);
+  Value *NextNeedle =
+      Builder.CreateGEP(CharTy, Needle, ConstVF, "needle_next_vec");
+  Builder.CreateCondBr(Builder.CreateICmpULT(NextNeedle, NeedleEnd), BB2, BB5);
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, BB4, BB2}, {DominatorTree::Insert, BB4, BB5}});
+
+  // (5) Check if we've reached the end of the search array.
+  Builder.SetInsertPoint(BB5);
+  Value *NextSearch =
+      Builder.CreateGEP(CharTy, Search, ConstVF, "search_next_vec");
+  Builder.CreateCondBr(Builder.CreateICmpULT(NextSearch, SearchEnd), BB1,
+                       ExitFail);
+  DTU.applyUpdates({{DominatorTree::Insert, BB5, BB1},
+                    {DominatorTree::Insert, BB5, ExitFail}});
+
+  // Set up the PHI nodes.
+  Search->addIncoming(SearchStart, BB0);
+  Search->addIncoming(NextSearch, BB5);
+  Needle->addIncoming(NeedleStart, BB1);
+  Needle->addIncoming(NextNeedle, BB4);
+  // These are needed to retain LCSSA form.
+  MatchLCSSA->addIncoming(Search, BB2);
+  MatchPredLCSSA->addIncoming(MatchPred, BB2);
+
+  if (VerifyLoops) {
+    OuterLoop->verifyLoop();
+    InnerLoop->verifyLoop();
+    if (!OuterLoop->isRecursivelyLCSSAForm(*DT, *LI))
+      report_fatal_error("Loops must remain in LCSSA form!");
+  }
+
+  return MatchVal;
+}
+
+void LoopIdiomVectorize::transformFindFirstByte(
+    PHINode *IndPhi, unsigned VF, Type *CharTy, BasicBlock *ExitSucc,
+    BasicBlock *ExitFail, Value *SearchStart, Value *SearchEnd,
+    Value *NeedleStart, Value *NeedleEnd) {
+  // Insert the find first byte code at the end of the preheader block.
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
+  IRBuilder<> Builder(PHBranch);
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
+
+  Value *MatchVal =
+      expandFindFirstByte(Builder, DTU, VF, CharTy, ExitSucc, ExitFail,
+                          SearchStart, SearchEnd, NeedleStart, NeedleEnd);
+
+  assert(PHBranch->isUnconditional() &&
+         "Expected preheader to terminate with an unconditional branch.");
+
+  // Add new incoming values with the result of the transformation to PHINodes
+  // of ExitSucc that use IndPhi.
+  for (auto *U : llvm::make_early_inc_range(IndPhi->users())) {
+    auto *PN = dyn_cast<PHINode>(U);
----------------
david-arm wrote:

Sorry my mistake! It's because it doesn't have the `if (!CurLoop->contains(cast<Instruction>(U)))` guard so it could be inside the loop.

https://github.com/llvm/llvm-project/pull/101976


More information about the llvm-commits mailing list