[llvm] [LoopIdiomVectorize][NFC] Factoring out the part that handles vectorization strategy (PR #94682)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 21 11:06:37 PDT 2024


https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/94682

>From c0b709210f29b7736bf990d8320efdbd95e67092 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Thu, 6 Jun 2024 13:20:55 -0700
Subject: [PATCH] [LoopIdiomVectorize][NFC] Factoring out the part that handles
 vectorization strategy

To pave the way for porting LoopIdiomVectorize (LIV) to RISC-V, which
uses VP intrinsics for vectors.

NFC.
---
 .../Vectorize/LoopIdiomVectorize.cpp          | 240 ++++++++++--------
 1 file changed, 133 insertions(+), 107 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
index 38095b1433ebe..d4a417f4be8ad 100644
--- a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
@@ -78,6 +78,13 @@ class LoopIdiomVectorize {
   const TargetTransformInfo *TTI;
   const DataLayout *DL;
 
+  // Blocks that will be used for inserting vectorized code.
+  BasicBlock *EndBlock = nullptr;
+  BasicBlock *VectorLoopPreheaderBlock = nullptr;
+  BasicBlock *VectorLoopStartBlock = nullptr;
+  BasicBlock *VectorLoopMismatchBlock = nullptr;
+  BasicBlock *VectorLoopIncBlock = nullptr;
+
 public:
   explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
                               const TargetTransformInfo *TTI,
@@ -95,9 +102,16 @@ class LoopIdiomVectorize {
                       SmallVectorImpl<BasicBlock *> &ExitBlocks);
 
   bool recognizeByteCompare();
+
   Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                             GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                             Instruction *Index, Value *Start, Value *MaxLen);
+
+  Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
+                                  GetElementPtrInst *GEPA,
+                                  GetElementPtrInst *GEPB, Value *ExtStart,
+                                  Value *ExtEnd);
+
   void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                             PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                             Value *Start, bool IncIdx, BasicBlock *FoundBB,
@@ -331,6 +345,115 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
   return true;
 }
 
+Value *LoopIdiomVectorize::createMaskedFindMismatch(
+    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
+    GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
+  Type *I64Type = Builder.getInt64Ty();
+  Type *ResType = Builder.getInt32Ty();
+  Type *LoadType = Builder.getInt8Ty();
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // At this point we know two things must be true:
+  //  1. Start <= End
+  //  2. ExtMaxLen <= MinPageSize due to the page checks.
+  // Therefore, we know that we can use a 64-bit induction variable that
+  // starts from 0 -> ExtMaxLen and it will not overflow.
+  ScalableVectorType *PredVTy =
+      ScalableVectorType::get(Builder.getInt1Ty(), 16);
+
+  Value *InitialPred = Builder.CreateIntrinsic(
+      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
+
+  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
+                             /*HasNUW=*/true, /*HasNSW=*/true);
+
+  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
+                                            Builder.getInt1(false));
+
+  BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
+  Builder.Insert(JumpToVectorLoop);
+
+  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
+                     VectorLoopStartBlock}});
+
+  // Set up the first vector loop block by creating the PHIs, doing the vector
+  // loads and comparing the vectors.
+  Builder.SetInsertPoint(VectorLoopStartBlock);
+  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
+  LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
+  PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
+  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
+  Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
+  Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
+
+  Value *VectorLhsGep =
+      Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
+  Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
+                                                  Align(1), LoopPred, Passthru);
+
+  Value *VectorRhsGep =
+      Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
+  Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
+                                                  Align(1), LoopPred, Passthru);
+
+  Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
+  VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
+  Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
+  BranchInst *VectorEarlyExit = BranchInst::Create(
+      VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
+  Builder.Insert(VectorEarlyExit);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
+       {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
+
+  // Increment the index counter and calculate the predicate for the next
+  // iteration of the loop. We branch back to the start of the loop if there
+  // is at least one active lane.
+  Builder.SetInsertPoint(VectorLoopIncBlock);
+  Value *NewVectorIndexPhi =
+      Builder.CreateAdd(VectorIndexPhi, VecLen, "",
+                        /*HasNUW=*/true, /*HasNSW=*/true);
+  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
+  Value *NewPred =
+      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
+                              {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
+  LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
+
+  Value *PredHasActiveLanes =
+      Builder.CreateExtractElement(NewPred, uint64_t(0));
+  BranchInst *VectorLoopBranchBack =
+      BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
+  Builder.Insert(VectorLoopBranchBack);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
+       {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
+
+  // If we found a mismatch then we need to calculate which lane in the vector
+  // had a mismatch and add that on to the current loop index.
+  Builder.SetInsertPoint(VectorLoopMismatchBlock);
+  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
+  FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
+  PHINode *LastLoopPred =
+      Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
+  LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
+  PHINode *VectorFoundIndex =
+      Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
+  VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
+
+  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
+  Value *Ctz = Builder.CreateIntrinsic(
+      Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
+      {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
+  Ctz = Builder.CreateZExt(Ctz, I64Type);
+  Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
+                                             /*HasNUW=*/true, /*HasNSW=*/true);
+  return Builder.CreateTrunc(VectorLoopRes64, ResType);
+}
+
 Value *LoopIdiomVectorize::expandFindMismatch(
     IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
     GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -345,8 +468,7 @@ Value *LoopIdiomVectorize::expandFindMismatch(
   Type *ResType = Builder.getInt32Ty();
 
   // Split block in the original loop preheader.
-  BasicBlock *EndBlock =
-      SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
+  EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
 
   // Create the blocks that we're going to need:
   //  1. A block for checking the zero-extended length exceeds 0
@@ -370,17 +492,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
   BasicBlock *MemCheckBlock = BasicBlock::Create(
       Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
 
-  BasicBlock *VectorLoopPreheaderBlock = BasicBlock::Create(
+  VectorLoopPreheaderBlock = BasicBlock::Create(
       Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);
 
-  BasicBlock *VectorLoopStartBlock = BasicBlock::Create(
-      Ctx, "mismatch_vec_loop", EndBlock->getParent(), EndBlock);
+  VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
+                                            EndBlock->getParent(), EndBlock);
 
-  BasicBlock *VectorLoopIncBlock = BasicBlock::Create(
-      Ctx, "mismatch_vec_loop_inc", EndBlock->getParent(), EndBlock);
+  VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
+                                          EndBlock->getParent(), EndBlock);
 
-  BasicBlock *VectorLoopMismatchBlock = BasicBlock::Create(
-      Ctx, "mismatch_vec_loop_found", EndBlock->getParent(), EndBlock);
+  VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
+                                               EndBlock->getParent(), EndBlock);
 
   BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
       Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
@@ -491,104 +613,8 @@ Value *LoopIdiomVectorize::expandFindMismatch(
   // processed in each iteration, etc.
   Builder.SetInsertPoint(VectorLoopPreheaderBlock);
 
-  // At this point we know two things must be true:
-  //  1. Start <= End
-  //  2. ExtMaxLen <= MinPageSize due to the page checks.
-  // Therefore, we know that we can use a 64-bit induction variable that
-  // starts from 0 -> ExtMaxLen and it will not overflow.
-  ScalableVectorType *PredVTy =
-      ScalableVectorType::get(Builder.getInt1Ty(), 16);
-
-  Value *InitialPred = Builder.CreateIntrinsic(
-      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
-
-  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
-  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
-                             /*HasNUW=*/true, /*HasNSW=*/true);
-
-  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
-                                            Builder.getInt1(false));
-
-  BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
-  Builder.Insert(JumpToVectorLoop);
-
-  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
-                     VectorLoopStartBlock}});
-
-  // Set up the first vector loop block by creating the PHIs, doing the vector
-  // loads and comparing the vectors.
-  Builder.SetInsertPoint(VectorLoopStartBlock);
-  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
-  LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
-  PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
-  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
-  Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
-  Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
-
-  Value *VectorLhsGep =
-      Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
-  Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
-                                                  Align(1), LoopPred, Passthru);
-
-  Value *VectorRhsGep =
-      Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
-  Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
-                                                  Align(1), LoopPred, Passthru);
-
-  Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
-  VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
-  Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
-  BranchInst *VectorEarlyExit = BranchInst::Create(
-      VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
-  Builder.Insert(VectorEarlyExit);
-
-  DTU.applyUpdates(
-      {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
-       {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
-
-  // Increment the index counter and calculate the predicate for the next
-  // iteration of the loop. We branch back to the start of the loop if there
-  // is at least one active lane.
-  Builder.SetInsertPoint(VectorLoopIncBlock);
-  Value *NewVectorIndexPhi =
-      Builder.CreateAdd(VectorIndexPhi, VecLen, "",
-                        /*HasNUW=*/true, /*HasNSW=*/true);
-  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
-  Value *NewPred =
-      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
-                              {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
-  LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
-
-  Value *PredHasActiveLanes =
-      Builder.CreateExtractElement(NewPred, uint64_t(0));
-  BranchInst *VectorLoopBranchBack =
-      BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
-  Builder.Insert(VectorLoopBranchBack);
-
-  DTU.applyUpdates(
-      {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
-       {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
-
-  // If we found a mismatch then we need to calculate which lane in the vector
-  // had a mismatch and add that on to the current loop index.
-  Builder.SetInsertPoint(VectorLoopMismatchBlock);
-  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
-  FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
-  PHINode *LastLoopPred =
-      Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
-  LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
-  PHINode *VectorFoundIndex =
-      Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
-  VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
-
-  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
-  Value *Ctz = Builder.CreateIntrinsic(
-      Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
-      {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
-  Ctz = Builder.CreateZExt(Ctz, I64Type);
-  Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
-                                             /*HasNUW=*/true, /*HasNSW=*/true);
-  Value *VectorLoopRes = Builder.CreateTrunc(VectorLoopRes64, ResType);
+  Value *VectorLoopRes =
+      createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
 
   Builder.Insert(BranchInst::Create(EndBlock));
 



More information about the llvm-commits mailing list