[llvm] [AArch64] Add MATCH loops to LoopIdiomVectorizePass (PR #101976)

Ricardo Jesus via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 28 08:20:27 PST 2024


================
@@ -939,3 +988,361 @@ void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
       report_fatal_error("Loops must remain in LCSSA form!");
   }
 }
+
+bool LoopIdiomVectorize::recognizeFindFirstByte() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // vectors too.
+  if (!TTI->supportsScalableVectors() || DisableFindFirstByte)
+    return false;
+
+  // Define some constants we need throughout.
+  BasicBlock *Header = CurLoop->getHeader();
+  LLVMContext &Ctx = Header->getContext();
+
+  // We are expecting the blocks below. For now, we will bail out for almost
+  // anything other than this.
+  //
+  // Header:
+  //   %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ]
+  //   %15 = load i8, ptr %14, align 1
+  //   br label %MatchBB
+  //
+  // MatchBB:
+  //   %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ]
+  //   %21 = load i8, ptr %20, align 1
+  //   %22 = icmp eq i8 %15, %21
+  //   br i1 %22, label %ExitSucc, label %InnerBB
+  //
+  // InnerBB:
+  //   %17 = getelementptr inbounds i8, ptr %20, i64 1
+  //   %18 = icmp eq ptr %17, %10
+  //   br i1 %18, label %OuterBB, label %MatchBB
+  //
+  // OuterBB:
+  //   %24 = getelementptr inbounds i8, ptr %14, i64 1
+  //   %25 = icmp eq ptr %24, %6
+  //   br i1 %25, label %ExitFail, label %Header
+
+  // We expect the four blocks above, which include one nested loop.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 ||
+      CurLoop->getSubLoops().size() != 1)
+    return false;
+
+  auto *InnerLoop = CurLoop->getSubLoops().front();
+  PHINode *IndPhi = dyn_cast<PHINode>(&Header->front());
+  if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
+    return false;
+
+  // Check instruction counts.
+  auto LoopBlocks = CurLoop->getBlocks();
+  if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[1]->sizeWithoutDebug() > 4 ||
+      LoopBlocks[2]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[3]->sizeWithoutDebug() > 3)
+    return false;
+
+  // Check that no instruction other than IndPhi has outside uses.
+  for (BasicBlock *BB : LoopBlocks)
+    for (Instruction &I : *BB)
+      if (&I != IndPhi)
+        for (User *U : I.users())
+          if (!CurLoop->contains(cast<Instruction>(U)))
+            return false;
+
+  // Match the branch instruction in the header. We are expecting an
+  // unconditional branch to the inner loop.
+  BasicBlock *MatchBB;
+  if (!match(Header->getTerminator(), m_UnconditionalBr(MatchBB)) ||
+      !InnerLoop->contains(MatchBB))
+    return false;
+
+  // MatchBB should be the entrypoint into the inner loop containing the
+  // comparison between a search element and a needle.
+  BasicBlock *ExitSucc, *InnerBB;
+  Value *LoadA, *LoadB;
+  ICmpInst::Predicate MatchPred;
+  if (!match(MatchBB->getTerminator(),
+             m_Br(m_ICmp(MatchPred, m_Value(LoadA), m_Value(LoadB)),
+                  m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
+      MatchPred != ICmpInst::Predicate::ICMP_EQ ||
+      !InnerLoop->contains(InnerBB))
+    return false;
+
+  // We expect outside uses of `IndPhi' in ExitSucc (and only there).
+  for (User *U : IndPhi->users())
+    if (!CurLoop->contains(cast<Instruction>(U)))
+      if (auto *PN = dyn_cast<PHINode>(U); !PN || PN->getParent() != ExitSucc)
+        return false;
+
+  // Match the loads and check they are simple.
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !cast<LoadInst>(LoadA)->isSimple() ||
+      !match(LoadB, m_Load(m_Value(B))) || !cast<LoadInst>(LoadB)->isSimple())
+    return false;
+
+  // Check we are loading valid characters.
+  Type *CharTy = LoadA->getType();
+  if (!CharTy->isIntegerTy() || LoadB->getType() != CharTy)
+    return false;
+
+  // Choose the vectorisation factor, work out the cost of the match intrinsic
+  // and decide if we should use it.
+  // Note: VF could be parameterised, but 128-bit vectors sound like a good
+  // default choice for the time being.
+  unsigned VF = 128 / CharTy->getIntegerBitWidth();
+  SmallVector<Type *> Args = {
+      ScalableVectorType::get(CharTy, VF), FixedVectorType::get(CharTy, VF),
+      ScalableVectorType::get(Type::getInt1Ty(Ctx), VF)};
+  IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2],
+                                Args);
+  if (TTI->getIntrinsicInstrCost(Attrs, TTI::TCK_SizeAndLatency) > 4)
+    return false;
+
+  // The loads come from two PHIs, each with two incoming values.
+  PHINode *PNA = dyn_cast<PHINode>(A);
+  PHINode *PNB = dyn_cast<PHINode>(B);
+  if (!PNA || PNA->getNumIncomingValues() != 2 || !PNB ||
+      PNB->getNumIncomingValues() != 2)
+    return false;
+
+  // One PHI comes from the outer loop (PNA), the other one from the inner loop
+  // (PNB). PNA effectively corresponds to IndPhi.
+  if (InnerLoop->contains(PNA))
+    std::swap(PNA, PNB);
+  if (PNA != &Header->front() || PNB != &MatchBB->front())
+    return false;
+
+  // Each PHI should take a start value and a GEP that advances it by 1.
+  Value *StartA = PNA->getIncomingValue(0);
+  Value *IndexA = PNA->getIncomingValue(1);
+  if (CurLoop->contains(PNA->getIncomingBlock(0)))
+    std::swap(StartA, IndexA);
+
+  Value *StartB = PNB->getIncomingValue(0);
+  Value *IndexB = PNB->getIncomingValue(1);
+  if (InnerLoop->contains(PNB->getIncomingBlock(0)))
+    std::swap(StartB, IndexB);
+
+  // Match the GEPs.
+  if (!match(IndexA, m_GEP(m_Specific(PNA), m_One())) ||
+      !match(IndexB, m_GEP(m_Specific(PNB), m_One())))
+    return false;
+
+  // Check their result type matches `CharTy'.
+  GetElementPtrInst *GEPA = cast<GetElementPtrInst>(IndexA);
+  GetElementPtrInst *GEPB = cast<GetElementPtrInst>(IndexB);
+  if (GEPA->getResultElementType() != CharTy ||
+      GEPB->getResultElementType() != CharTy)
+    return false;
+
+  // InnerBB should increment the address of the needle pointer.
+  BasicBlock *OuterBB;
+  Value *EndB;
+  if (!match(InnerBB->getTerminator(),
+             m_Br(m_ICmp(MatchPred, m_Specific(GEPB), m_Value(EndB)),
+                  m_BasicBlock(OuterBB), m_Specific(MatchBB))) ||
+      MatchPred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(OuterBB))
+    return false;
+
+  // OuterBB should increment the address of the search element pointer.
+  BasicBlock *ExitFail;
+  Value *EndA;
+  if (!match(OuterBB->getTerminator(),
+             m_Br(m_ICmp(MatchPred, m_Specific(GEPA), m_Value(EndA)),
+                  m_BasicBlock(ExitFail), m_Specific(Header))) ||
+      MatchPred != ICmpInst::Predicate::ICMP_EQ)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n" << *CurLoop << "\n\n");
----------------
rj-jesus wrote:

Thanks, I think this was from `LoopIdiomVectorize::recognizeByteCompare` but I've formatted it properly now.

https://github.com/llvm/llvm-project/pull/101976


More information about the llvm-commits mailing list