[llvm] [AArch64] Add MATCH loops to LoopIdiomVectorizePass (PR #101976)

Ricardo Jesus via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 21 01:10:26 PST 2024


https://github.com/rj-jesus updated https://github.com/llvm/llvm-project/pull/101976

>From f2e0f08bb333c8560615b5e11d73827bb2f781b7 Mon Sep 17 00:00:00 2001
From: Ricardo Jesus <rjj at nvidia.com>
Date: Mon, 15 Jul 2024 17:57:30 +0100
Subject: [PATCH] [AArch64] Add MATCH loops to LoopIdiomVectorizePass

This patch adds a new loop to LoopIdiomVectorizePass, enabling it to
recognise and use @llvm.experimental.vector.match to vectorise loops
such as:

    char* find_first_of(char *first, char *last,
                        char *s_first, char *s_last) {
      for (; first != last; ++first)
        for (char *it = s_first; it != s_last; ++it)
          if (*first == *it)
            return first;
      return last;
    }

These loops match the C++ standard library's std::find_first_of.
---
 .../llvm/Analysis/TargetTransformInfoImpl.h   |   4 +-
 .../AArch64/AArch64TargetTransformInfo.cpp    |  17 +
 .../Vectorize/LoopIdiomVectorize.cpp          | 425 +++++++++++++++++-
 llvm/test/CodeGen/AArch64/find-first-byte.ll  | 123 +++++
 4 files changed, 559 insertions(+), 10 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/find-first-byte.ll

diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 72038c090b7922..d5c76f7f6a5ee2 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -775,7 +775,9 @@ class TargetTransformInfoImplBase {
     default:
       break;
     case Intrinsic::experimental_vector_histogram_add:
-      // For now, we want explicit support from the target for histograms.
+    case Intrinsic::experimental_vector_match:
+      // For now, we want explicit support from the target for histograms and
+      // matches.
       return InstructionCost::getInvalid();
     case Intrinsic::allow_runtime_check:
     case Intrinsic::allow_ubsan_check:
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index ec7bb71fd111ff..839563e31cb87e 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -914,6 +914,23 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
     }
     break;
   }
+  case Intrinsic::experimental_vector_match: {
+    EVT SearchVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
+    unsigned SearchSize =
+        cast<FixedVectorType>(ICA.getArgTypes()[1])->getNumElements();
+    // If we can't lower to MATCH, return an invalid cost.
+    if (getTLI()->shouldExpandVectorMatch(SearchVT, SearchSize))
+      return InstructionCost::getInvalid();
+    // Base cost for MATCH instructions. At least on the Neoverse V2 and
+    // Neoverse V3 these are cheap operations with the same latency as a vector
+    // ADD, though in most cases we also need to do an extra DUP.
+    InstructionCost Cost = 4;
+    // For fixed-length vectors we currently need an extra five--six
+    // instructions besides the MATCH.
+    if (isa<FixedVectorType>(RetTy))
+      Cost += 6;
+    return Cost;
+  }
   default:
     break;
   }
diff --git a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
index 7af7408ed67a8c..dbc2f55e2c0ec8 100644
--- a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
@@ -10,8 +10,10 @@
 // transforms them into more optimized versions of the same loop. In cases
 // where this happens, it can be a significant performance win.
 //
-// We currently only recognize one loop that finds the first mismatched byte
-// in an array and returns the index, i.e. something like:
+// We currently support two loops:
+//
+// 1. A loop that finds the first mismatched byte in an array and returns the
+// index, i.e. something like:
 //
 //  while (++i != n) {
 //    if (a[i] != b[i])
@@ -24,12 +26,6 @@
 // boundaries. However, even with these checks it is still profitable to do the
 // transformation.
 //
-//===----------------------------------------------------------------------===//
-//
-// NOTE: This Pass matches a really specific loop pattern because it's only
-// supposed to be a temporary solution until our LoopVectorizer is powerful
-// enought to vectorize it automatically.
-//
 // TODO List:
 //
 // * Add support for the inverse case where we scan for a matching element.
@@ -37,6 +33,35 @@
 // * Recognize loops that increment the IV *after* comparing bytes.
 // * Allow 32-bit sign-extends of the IV used by the GEP.
 //
+// 2. A loop that finds the first matching character in an array among a set of
+// possible matches, e.g.:
+//
+//   for (; first != last; ++first)
+//     for (s_it = s_first; s_it != s_last; ++s_it)
+//       if (*first == *s_it)
+//         return first;
+//   return last;
+//
+// This corresponds to std::find_first_of (for arrays of bytes) from the C++
+// standard library. This function can be implemented efficiently for targets
+// that support @llvm.experimental.vector.match. For example, on AArch64 targets
+// that implement SVE2, this lowers to a MATCH instruction, which enables us to
+// perform up to 16x16=256 comparisons in one go. This can lead to very
+// significant speedups.
+//
+// TODO:
+//
+// * Add support for `find_first_not_of' loops (i.e. with not-equal comparison).
+// * Make VF a configurable parameter (right now we assume 128-bit vectors).
+// * Potentially adjust the cost model to let the transformation kick in even if
+//   @llvm.experimental.vector.match doesn't have direct support in hardware.
+//
+//===----------------------------------------------------------------------===//
+//
+// NOTE: This Pass matches really specific loop patterns because it's only
+// supposed to be a temporary solution until our LoopVectorizer is powerful
+// enough to vectorize them automatically.
+//
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
@@ -79,6 +104,12 @@ static cl::opt<unsigned>
               cl::desc("The vectorization factor for byte-compare patterns."),
               cl::init(16));
 
+static cl::opt<bool>
+    DisableFindFirstByte("disable-loop-idiom-vectorize-find-first-byte",
+                         cl::Hidden, cl::init(false),
+                         cl::desc("Proceed with Loop Idiom Vectorize Pass, but "
+                                  "do not convert find-first-byte loop(s)."));
+
 static cl::opt<bool>
     VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
                 cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
@@ -136,6 +167,18 @@ class LoopIdiomVectorize {
                             PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                             Value *Start, bool IncIdx, BasicBlock *FoundBB,
                             BasicBlock *EndBB);
+
+  bool recognizeFindFirstByte();
+
+  Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU,
+                             unsigned VF, Type *CharTy, BasicBlock *ExitSucc,
+                             BasicBlock *ExitFail, Value *StartA, Value *EndA,
+                             Value *StartB, Value *EndB);
+
+  void transformFindFirstByte(PHINode *IndPhi, unsigned VF, Type *CharTy,
+                              BasicBlock *ExitSucc, BasicBlock *ExitFail,
+                              Value *StartA, Value *EndA, Value *StartB,
+                              Value *EndB);
   /// @}
 };
 } // anonymous namespace
@@ -190,7 +233,13 @@ bool LoopIdiomVectorize::run(Loop *L) {
   LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
                     << CurLoop->getHeader()->getName() << "\n");
 
-  return recognizeByteCompare();
+  if (recognizeByteCompare())
+    return true;
+
+  if (recognizeFindFirstByte())
+    return true;
+
+  return false;
 }
 
 bool LoopIdiomVectorize::recognizeByteCompare() {
@@ -939,3 +988,361 @@ void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
       report_fatal_error("Loops must remain in LCSSA form!");
   }
 }
+
+bool LoopIdiomVectorize::recognizeFindFirstByte() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // vectors too.
+  if (!TTI->supportsScalableVectors() || DisableFindFirstByte)
+    return false;
+
+  // Define some constants we need throughout.
+  BasicBlock *Header = CurLoop->getHeader();
+  LLVMContext &Ctx = Header->getContext();
+
+  // We are expecting the blocks below. For now, we will bail out for almost
+  // anything other than this.
+  //
+  // Header:
+  //   %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ]
+  //   %15 = load i8, ptr %14, align 1
+  //   br label %MatchBB
+  //
+  // MatchBB:
+  //   %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ]
+  //   %21 = load i8, ptr %20, align 1
+  //   %22 = icmp eq i8 %15, %21
+  //   br i1 %22, label %ExitSucc, label %InnerBB
+  //
+  // InnerBB:
+  //   %17 = getelementptr inbounds i8, ptr %20, i64 1
+  //   %18 = icmp eq ptr %17, %10
+  //   br i1 %18, label %OuterBB, label %MatchBB
+  //
+  // OuterBB:
+  //   %24 = getelementptr inbounds i8, ptr %14, i64 1
+  //   %25 = icmp eq ptr %24, %6
+  //   br i1 %25, label %ExitFail, label %Header
+
+  // We expect the four blocks above, which include one nested loop.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 ||
+      CurLoop->getSubLoops().size() != 1)
+    return false;
+
+  auto *InnerLoop = CurLoop->getSubLoops().front();
+  PHINode *IndPhi = dyn_cast<PHINode>(&Header->front());
+  if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
+    return false;
+
+  // Check instruction counts.
+  auto LoopBlocks = CurLoop->getBlocks();
+  if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[1]->sizeWithoutDebug() > 4 ||
+      LoopBlocks[2]->sizeWithoutDebug() > 3 ||
+      LoopBlocks[3]->sizeWithoutDebug() > 3)
+    return false;
+
+  // Check that no instruction other than IndPhi has outside uses.
+  for (BasicBlock *BB : LoopBlocks)
+    for (Instruction &I : *BB)
+      if (&I != IndPhi)
+        for (User *U : I.users())
+          if (!CurLoop->contains(cast<Instruction>(U)))
+            return false;
+
+  // Match the branch instruction in the header. We are expecting an
+  // unconditional branch to the inner loop.
+  BasicBlock *MatchBB;
+  if (!match(Header->getTerminator(), m_UnconditionalBr(MatchBB)) ||
+      !InnerLoop->contains(MatchBB))
+    return false;
+
+  // MatchBB should be the entrypoint into the inner loop containing the
+  // comparison between a search element and a needle.
+  BasicBlock *ExitSucc, *InnerBB;
+  Value *LoadA, *LoadB;
+  ICmpInst::Predicate MatchPred;
+  if (!match(MatchBB->getTerminator(),
+             m_Br(m_ICmp(MatchPred, m_Value(LoadA), m_Value(LoadB)),
+                  m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
+      MatchPred != ICmpInst::Predicate::ICMP_EQ ||
+      !InnerLoop->contains(InnerBB))
+    return false;
+
+  // We expect outside uses of `IndPhi' in ExitSucc (and only there).
+  for (User *U : IndPhi->users())
+    if (!CurLoop->contains(cast<Instruction>(U)))
+      if (auto *PN = dyn_cast<PHINode>(U); !PN || PN->getParent() != ExitSucc)
+        return false;
+
+  // Match the loads and check they are simple.
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !cast<LoadInst>(LoadA)->isSimple() ||
+      !match(LoadB, m_Load(m_Value(B))) || !cast<LoadInst>(LoadB)->isSimple())
+    return false;
+
+  // Check we are loading valid characters.
+  Type *CharTy = LoadA->getType();
+  if (!CharTy->isIntegerTy() || LoadB->getType() != CharTy)
+    return false;
+
+  // Choose the vectorisation factor, work out the cost of the match intrinsic
+  // and decide if we should use it.
+  // Note: VF could be parameterised, but 128-bit vectors sound like a good
+  // default choice for the time being.
+  unsigned VF = 128 / CharTy->getIntegerBitWidth();
+  SmallVector<Type *> Args = {
+      ScalableVectorType::get(CharTy, VF), FixedVectorType::get(CharTy, VF),
+      ScalableVectorType::get(Type::getInt1Ty(Ctx), VF)};
+  IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2],
+                                Args);
+  if (TTI->getIntrinsicInstrCost(Attrs, TTI::TCK_SizeAndLatency) > 4)
+    return false;
+
+  // The loads come from two PHIs, each with two incoming values.
+  PHINode *PNA = dyn_cast<PHINode>(A);
+  PHINode *PNB = dyn_cast<PHINode>(B);
+  if (!PNA || PNA->getNumIncomingValues() != 2 || !PNB ||
+      PNB->getNumIncomingValues() != 2)
+    return false;
+
+  // One PHI comes from the outer loop (PNA), the other one from the inner loop
+  // (PNB). PNA effectively corresponds to IndPhi.
+  if (InnerLoop->contains(PNA))
+    std::swap(PNA, PNB);
+  if (PNA != &Header->front() || PNB != &MatchBB->front())
+    return false;
+
+  // The incoming values of both PHI nodes should be a gep of 1.
+  Value *StartA = PNA->getIncomingValue(0);
+  Value *IndexA = PNA->getIncomingValue(1);
+  if (CurLoop->contains(PNA->getIncomingBlock(0)))
+    std::swap(StartA, IndexA);
+
+  Value *StartB = PNB->getIncomingValue(0);
+  Value *IndexB = PNB->getIncomingValue(1);
+  if (InnerLoop->contains(PNB->getIncomingBlock(0)))
+    std::swap(StartB, IndexB);
+
+  // Match the GEPs.
+  if (!match(IndexA, m_GEP(m_Specific(PNA), m_One())) ||
+      !match(IndexB, m_GEP(m_Specific(PNB), m_One())))
+    return false;
+
+  // Check their result type matches `CharTy'.
+  GetElementPtrInst *GEPA = cast<GetElementPtrInst>(IndexA);
+  GetElementPtrInst *GEPB = cast<GetElementPtrInst>(IndexB);
+  if (GEPA->getResultElementType() != CharTy ||
+      GEPB->getResultElementType() != CharTy)
+    return false;
+
+  // InnerBB should increment the address of the needle pointer.
+  BasicBlock *OuterBB;
+  Value *EndB;
+  if (!match(InnerBB->getTerminator(),
+             m_Br(m_ICmp(MatchPred, m_Specific(GEPB), m_Value(EndB)),
+                  m_BasicBlock(OuterBB), m_Specific(MatchBB))) ||
+      MatchPred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(OuterBB))
+    return false;
+
+  // OuterBB should increment the address of the search element pointer.
+  BasicBlock *ExitFail;
+  Value *EndA;
+  if (!match(OuterBB->getTerminator(),
+             m_Br(m_ICmp(MatchPred, m_Specific(GEPA), m_Value(EndA)),
+                  m_BasicBlock(ExitFail), m_Specific(Header))) ||
+      MatchPred != ICmpInst::Predicate::ICMP_EQ)
+    return false;
+
+  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n" << *CurLoop << "\n\n");
+
+  transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, StartA, EndA,
+                         StartB, EndB);
+  return true;
+}
+
+Value *LoopIdiomVectorize::expandFindFirstByte(
+    IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy,
+    BasicBlock *ExitSucc, BasicBlock *ExitFail, Value *StartA, Value *EndA,
+    Value *StartB, Value *EndB) {
+  // Set up some types and constants that we intend to reuse.
+  auto *PtrTy = Builder.getPtrTy();
+  auto *I64Ty = Builder.getInt64Ty();
+  auto *PredVTy = ScalableVectorType::get(Builder.getInt1Ty(), VF);
+  auto *CharVTy = ScalableVectorType::get(CharTy, VF);
+  auto *ConstVF = ConstantInt::get(I64Ty, VF);
+
+  // Other common arguments.
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  LLVMContext &Ctx = Preheader->getContext();
+  Value *Passthru = ConstantInt::getNullValue(CharVTy);
+
+  // Split block in the original loop preheader.
+  // SPH is the new preheader to the old scalar loop.
+  BasicBlock *SPH = SplitBlock(Preheader, Preheader->getTerminator(), DT, LI,
+                               nullptr, "scalar_ph");
+
+  // Create the blocks that we're going to use.
+  //
+  // We will have the following loops:
+  // (O) Outer loop where we iterate over the elements of the search array (A).
+  // (I) Inner loop where we iterate over the elements of the needle array (B).
+  //
+  // Overall, the blocks do the following:
+  // (1) Load a vector's worth of A. Go to (2).
+  // (2) (a) Load a vector's worth of B.
+  //     (b) Splat the first element of B to the inactive lanes.
+  //     (c) Check if any elements match. If so go to (3), otherwise go to (4).
+  // (3) Compute the index of the first match and exit.
+  // (4) Check if we've reached the end of B. If not loop back to (2), otherwise
+  //     go to (5).
+  // (5) Check if we've reached the end of A. If not loop back to (1), otherwise
+  //     exit.
+  // Block (3) is not part of any loop. Blocks (1,5) and (2,4) belong to the
+  // outer and inner loops, respectively.
+  BasicBlock *BB1 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+  BasicBlock *BB2 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+  BasicBlock *BB3 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+  BasicBlock *BB4 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+  BasicBlock *BB5 = BasicBlock::Create(Ctx, "", SPH->getParent(), SPH);
+
+  // Update LoopInfo with the new loops.
+  auto OL = LI->AllocateLoop();
+  auto IL = LI->AllocateLoop();
+
+  if (auto ParentLoop = CurLoop->getParentLoop()) {
+    ParentLoop->addChildLoop(OL);
+    ParentLoop->addBasicBlockToLoop(BB3, *LI);
+  } else {
+    LI->addTopLevelLoop(OL);
+  }
+
+  // Add the inner loop to the outer.
+  OL->addChildLoop(IL);
+
+  // Add the new basic blocks to the corresponding loops.
+  OL->addBasicBlockToLoop(BB1, *LI);
+  OL->addBasicBlockToLoop(BB5, *LI);
+  IL->addBasicBlockToLoop(BB2, *LI);
+  IL->addBasicBlockToLoop(BB4, *LI);
+
+  // Set a reference to the old scalar loop and create a predicate of VF
+  // elements.
+  Builder.SetInsertPoint(Preheader->getTerminator());
+  Value *Pred16 =
+      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
+                              {ConstantInt::get(I64Ty, 0), ConstVF});
+  Builder.CreateCondBr(Builder.getFalse(), SPH, BB1);
+  Preheader->getTerminator()->eraseFromParent();
+  DTU.applyUpdates({{DominatorTree::Insert, Preheader, BB1}});
+
+  // (1) Load a vector's worth of A and branch to the inner loop.
+  Builder.SetInsertPoint(BB1);
+  PHINode *PNA = Builder.CreatePHI(PtrTy, 2, "pa");
+  Value *PredA =
+      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
+                              {Builder.CreatePointerCast(PNA, I64Ty),
+                               Builder.CreatePointerCast(EndA, I64Ty)});
+  PredA = Builder.CreateAnd(Pred16, PredA);
+  Value *LoadA =
+      Builder.CreateMaskedLoad(CharVTy, PNA, Align(1), PredA, Passthru);
+  Builder.CreateBr(BB2);
+  DTU.applyUpdates({{DominatorTree::Insert, BB1, BB2}});
+
+  // (2) Inner loop.
+  Builder.SetInsertPoint(BB2);
+  PHINode *PNB = Builder.CreatePHI(PtrTy, 2, "pb");
+
+  // (2.a) Load a vector's worth of B.
+  Value *PredB =
+      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
+                              {Builder.CreatePointerCast(PNB, I64Ty),
+                               Builder.CreatePointerCast(EndB, I64Ty)});
+  PredB = Builder.CreateAnd(Pred16, PredB);
+  Value *LoadB =
+      Builder.CreateMaskedLoad(CharVTy, PNB, Align(1), PredB, Passthru);
+
+  // (2.b) Splat the first element to the inactive lanes.
+  Value *LoadB0 = Builder.CreateExtractElement(LoadB, uint64_t(0));
+  Value *LoadB0Splat =
+      Builder.CreateVectorSplat(ElementCount::getScalable(VF), LoadB0);
+  LoadB = Builder.CreateSelect(PredB, LoadB, LoadB0Splat);
+  LoadB = Builder.CreateExtractVector(FixedVectorType::get(CharTy, VF), LoadB,
+                                      ConstantInt::get(I64Ty, 0));
+
+  // (2.c) Test if there's a match.
+  Value *MatchPred = Builder.CreateIntrinsic(
+      Intrinsic::experimental_vector_match, {CharVTy, LoadB->getType()},
+      {LoadA, LoadB, PredA});
+  Value *IfAnyMatch = Builder.CreateOrReduce(MatchPred);
+  Builder.CreateCondBr(IfAnyMatch, BB3, BB4);
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, BB2, BB3}, {DominatorTree::Insert, BB2, BB4}});
+
+  // (3) We found a match. Compute the index of its location and exit.
+  Builder.SetInsertPoint(BB3);
+  Value *MatchCnt = Builder.CreateIntrinsic(
+      Intrinsic::experimental_cttz_elts, {I64Ty, MatchPred->getType()},
+      {MatchPred, /*ZeroIsPoison=*/Builder.getInt1(true)});
+  Value *MatchVal = Builder.CreateGEP(CharTy, PNA, MatchCnt);
+  Builder.CreateBr(ExitSucc);
+  DTU.applyUpdates({{DominatorTree::Insert, BB3, ExitSucc}});
+
+  // (4) Check if we've reached the end of B.
+  Builder.SetInsertPoint(BB4);
+  Value *IncB = Builder.CreateGEP(CharTy, PNB, ConstVF);
+  Builder.CreateCondBr(Builder.CreateICmpULT(IncB, EndB), BB2, BB5);
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, BB4, BB2}, {DominatorTree::Insert, BB4, BB5}});
+
+  // (5) Check if we've reached the end of A.
+  Builder.SetInsertPoint(BB5);
+  Value *IncA = Builder.CreateGEP(CharTy, PNA, ConstVF);
+  Builder.CreateCondBr(Builder.CreateICmpULT(IncA, EndA), BB1, ExitFail);
+  DTU.applyUpdates({{DominatorTree::Insert, BB5, BB1},
+                    {DominatorTree::Insert, BB5, ExitFail}});
+
+  // Set up the PHIs.
+  PNA->addIncoming(StartA, Preheader);
+  PNA->addIncoming(IncA, BB5);
+  PNB->addIncoming(StartB, BB1);
+  PNB->addIncoming(IncB, BB4);
+
+  if (VerifyLoops) {
+    OL->verifyLoop();
+    IL->verifyLoop();
+    if (!OL->isRecursivelyLCSSAForm(*DT, *LI))
+      report_fatal_error("Loops must remain in LCSSA form!");
+  }
+
+  return MatchVal;
+}
+
+void LoopIdiomVectorize::transformFindFirstByte(PHINode *IndPhi, unsigned VF,
+                                                Type *CharTy,
+                                                BasicBlock *ExitSucc,
+                                                BasicBlock *ExitFail,
+                                                Value *StartA, Value *EndA,
+                                                Value *StartB, Value *EndB) {
+  // Insert the find first byte code at the end of the preheader block.
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
+  IRBuilder<> Builder(PHBranch);
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
+
+  Value *MatchVal = expandFindFirstByte(Builder, DTU, VF, CharTy, ExitSucc,
+                                        ExitFail, StartA, EndA, StartB, EndB);
+
+  // Add new incoming values with the result of the transformation to PHINodes
+  // of ExitSucc that use IndPhi.
+  for (auto *U : llvm::make_early_inc_range(IndPhi->users()))
+    if (auto *PN = dyn_cast<PHINode>(U); PN && PN->getParent() == ExitSucc)
+      PN->addIncoming(MatchVal, cast<Instruction>(MatchVal)->getParent());
+
+  if (VerifyLoops && CurLoop->getParentLoop()) {
+    CurLoop->getParentLoop()->verifyLoop();
+    if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
+      report_fatal_error("Loops must remain in LCSSA form!");
+  }
+}
diff --git a/llvm/test/CodeGen/AArch64/find-first-byte.ll b/llvm/test/CodeGen/AArch64/find-first-byte.ll
new file mode 100644
index 00000000000000..e60553e95e13cf
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/find-first-byte.ll
@@ -0,0 +1,123 @@
+; RUN: opt -mattr=+sve2 -mtriple=aarch64 -passes='loop(loop-idiom-vectorize)' -S < %s | FileCheck -check-prefix=SVE2 %s
+; RUN: opt -mattr=-sve2 -mtriple=aarch64 -passes='loop(loop-idiom-vectorize)' -S < %s | FileCheck -check-prefix=NOSVE2 %s
+
+; Base case based on `libcxx/include/__algorithm/find_first_of.h':
+;   char* find_first_of(char *first, char *last, char *s_first, char *s_last) {
+;     for (; first != last; ++first)
+;       for (char *it = s_first; it != s_last; ++it)
+;         if (*first == *it)
+;           return first;
+;     return last;
+;   }
+define ptr @find_first_of_i8(ptr %0, ptr %1, ptr %2, ptr %3) #0 {
+; SVE2-LABEL: define ptr @find_first_of_i8(
+; SVE2:         {{%.*}} = call <vscale x 16 x i1> @llvm.experimental.vector.match.nxv16i8.v16i8(<vscale x 16 x i8> {{%.*}}, <16 x i8> {{%.*}}, <vscale x 16 x i1> {{%.*}})
+;
+; NOSVE2-LABEL: define ptr @find_first_of_i8(
+; NOSVE2-NOT:     {{%.*}} = call <vscale x 16 x i1> @llvm.experimental.vector.match.nxv16i8.v16i8(<vscale x 16 x i8> {{%.*}}, <16 x i8> {{%.*}}, <vscale x 16 x i1> {{%.*}})
+;
+  %5 = icmp eq ptr %0, %1
+  %6 = icmp eq ptr %2, %3
+  %7 = or i1 %5, %6
+  br i1 %7, label %21, label %8
+
+8:
+  %9 = phi ptr [ %19, %18 ], [ %0, %4 ]
+  %10 = load i8, ptr %9, align 1
+  br label %14
+
+11:
+  %12 = getelementptr inbounds i8, ptr %15, i64 1
+  %13 = icmp eq ptr %12, %3
+  br i1 %13, label %18, label %14
+
+14:
+  %15 = phi ptr [ %2, %8 ], [ %12, %11 ]
+  %16 = load i8, ptr %15, align 1
+  %17 = icmp eq i8 %10, %16
+  br i1 %17, label %21, label %11
+
+18:
+  %19 = getelementptr inbounds i8, ptr %9, i64 1
+  %20 = icmp eq ptr %19, %1
+  br i1 %20, label %21, label %8
+
+21:
+  %22 = phi ptr [ %1, %4 ], [ %9, %14 ], [ %1, %18 ]
+  ret ptr %22
+}
+
+; Same as @find_first_of_i8 but with i16.
+define ptr @find_first_of_i16(ptr %0, ptr %1, ptr %2, ptr %3) #0 {
+; SVE2-LABEL: define ptr @find_first_of_i16(
+; SVE2:         {{%.*}} = call <vscale x 8 x i1> @llvm.experimental.vector.match.nxv8i16.v8i16(<vscale x 8 x i16> {{%.*}}, <8 x i16> {{%.*}}, <vscale x 8 x i1> {{%.*}})
+;
+  %5 = icmp eq ptr %0, %1
+  %6 = icmp eq ptr %2, %3
+  %7 = or i1 %5, %6
+  br i1 %7, label %21, label %8
+
+8:
+  %9 = phi ptr [ %19, %18 ], [ %0, %4 ]
+  %10 = load i16, ptr %9, align 1
+  br label %14
+
+11:
+  %12 = getelementptr inbounds i16, ptr %15, i64 1
+  %13 = icmp eq ptr %12, %3
+  br i1 %13, label %18, label %14
+
+14:
+  %15 = phi ptr [ %2, %8 ], [ %12, %11 ]
+  %16 = load i16, ptr %15, align 1
+  %17 = icmp eq i16 %10, %16
+  br i1 %17, label %21, label %11
+
+18:
+  %19 = getelementptr inbounds i16, ptr %9, i64 1
+  %20 = icmp eq ptr %19, %1
+  br i1 %20, label %21, label %8
+
+21:
+  %22 = phi ptr [ %1, %4 ], [ %9, %14 ], [ %1, %18 ]
+  ret ptr %22
+}
+
+; Same as @find_first_of_i8 but with `ne' comparison.
+; This is rejected for now, but should eventually be supported.
+define ptr @find_first_not_of_i8(ptr %0, ptr %1, ptr %2, ptr %3) #0 {
+; SVE2-LABEL: define ptr @find_first_not_of_i8(
+; SVE2-NOT:     {{%.*}} = call <vscale x 16 x i1> @llvm.experimental.vector.match.nxv16i8.v16i8(<vscale x 16 x i8> {{%.*}}, <16 x i8> {{%.*}}, <vscale x 16 x i1> {{%.*}})
+;
+  %5 = icmp eq ptr %0, %1
+  %6 = icmp eq ptr %2, %3
+  %7 = or i1 %5, %6
+  br i1 %7, label %21, label %8
+
+8:
+  %9 = phi ptr [ %19, %18 ], [ %0, %4 ]
+  %10 = load i8, ptr %9, align 1
+  br label %14
+
+11:
+  %12 = getelementptr inbounds i8, ptr %15, i64 1
+  %13 = icmp eq ptr %12, %3
+  br i1 %13, label %18, label %14
+
+14:
+  %15 = phi ptr [ %2, %8 ], [ %12, %11 ]
+  %16 = load i8, ptr %15, align 1
+  %17 = icmp ne i8 %10, %16
+  br i1 %17, label %21, label %11
+
+18:
+  %19 = getelementptr inbounds i8, ptr %9, i64 1
+  %20 = icmp eq ptr %19, %1
+  br i1 %20, label %21, label %8
+
+21:
+  %22 = phi ptr [ %1, %4 ], [ %9, %14 ], [ %1, %18 ]
+  ret ptr %22
+}
+
+attributes #0 = { "target-features"="+sve2" }



More information about the llvm-commits mailing list