[llvm] [LoopIdiomVectorize][NFC] Factoring out the part that handles vectorization strategy (PR #94682)
Min-Yih Hsu via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 21 11:06:37 PDT 2024
https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/94682
>From c0b709210f29b7736bf990d8320efdbd95e67092 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Thu, 6 Jun 2024 13:20:55 -0700
Subject: [PATCH] [LoopIdiomVectorize][NFC] Factoring out the part that handles
vectorization strategy
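Move the code in expandFindMismatch that emits the masked vector loop into a
new helper, createMaskedFindMismatch. This paves the way for porting
LoopIdiomVectorize (LIV) to RISC-V, which uses VP intrinsics for vectors.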
NFC.
---
.../Vectorize/LoopIdiomVectorize.cpp | 240 ++++++++++--------
1 file changed, 133 insertions(+), 107 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
index 38095b1433ebe..d4a417f4be8ad 100644
--- a/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
@@ -78,6 +78,13 @@ class LoopIdiomVectorize {
const TargetTransformInfo *TTI;
const DataLayout *DL;
+ // Blocks that will be used for inserting vectorized code.
+ BasicBlock *EndBlock = nullptr;
+ BasicBlock *VectorLoopPreheaderBlock = nullptr;
+ BasicBlock *VectorLoopStartBlock = nullptr;
+ BasicBlock *VectorLoopMismatchBlock = nullptr;
+ BasicBlock *VectorLoopIncBlock = nullptr;
+
public:
explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
const TargetTransformInfo *TTI,
@@ -95,9 +102,16 @@ class LoopIdiomVectorize {
SmallVectorImpl<BasicBlock *> &ExitBlocks);
bool recognizeByteCompare();
+
Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
Instruction *Index, Value *Start, Value *MaxLen);
+
+ Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
+ GetElementPtrInst *GEPA,
+ GetElementPtrInst *GEPB, Value *ExtStart,
+ Value *ExtEnd);
+
void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
PHINode *IndPhi, Value *MaxLen, Instruction *Index,
Value *Start, bool IncIdx, BasicBlock *FoundBB,
@@ -331,6 +345,115 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
return true;
}
+/// Emit the vector part of the mismatch search using scalable masked loads
+/// predicated by llvm.get.active.lane.mask.
+///
+/// Preconditions (set up by expandFindMismatch):
+///  - EndBlock, VectorLoopPreheaderBlock, VectorLoopStartBlock,
+///    VectorLoopIncBlock and VectorLoopMismatchBlock have been created.
+///  - \p Builder is positioned in VectorLoopPreheaderBlock; this function
+///    emits that block's branch to VectorLoopStartBlock.
+///
+/// \p GEPA and \p GEPB supply the base pointers (and inbounds flags) of the
+/// two byte arrays; \p ExtStart and \p ExtEnd are the i64 bounds passed to
+/// llvm.get.active.lane.mask, with \p ExtStart used as the initial index.
+///
+/// Edges emitted (and inserted into the dominator tree through \p DTU):
+///   VectorLoopPreheaderBlock -> VectorLoopStartBlock
+///   VectorLoopStartBlock     -> VectorLoopMismatchBlock, VectorLoopIncBlock
+///   VectorLoopIncBlock       -> VectorLoopStartBlock, EndBlock
+///
+/// \returns the index of the first mismatching byte, truncated to i32 and
+/// computed in VectorLoopMismatchBlock. On return \p Builder is positioned in
+/// VectorLoopMismatchBlock, which is left without a terminator for the caller
+/// to complete.
+Value *LoopIdiomVectorize::createMaskedFindMismatch(
+    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
+    GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
+  // This helper relies on state prepared by expandFindMismatch; catch misuse
+  // early rather than emitting branches to null blocks.
+  assert(EndBlock && VectorLoopPreheaderBlock && VectorLoopStartBlock &&
+         VectorLoopMismatchBlock && VectorLoopIncBlock &&
+         "vector loop blocks must be created before emitting the loop");
+  assert(Builder.GetInsertBlock() == VectorLoopPreheaderBlock &&
+         "expected to start emitting in the vector loop preheader");
+
+  // Minimum number of i8 lanes per vector (each vector is vscale x 16 bytes).
+  // The predicate type, the load type and the per-iteration stride must all
+  // agree on this value, so it is defined exactly once.
+  constexpr unsigned MinNumBytes = 16;
+
+  Type *I64Type = Builder.getInt64Ty();
+  Type *ResType = Builder.getInt32Ty();
+  Type *LoadType = Builder.getInt8Ty();
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // At this point we know two things must be true:
+  // 1. Start <= End
+  // 2. ExtMaxLen <= MinPageSize due to the page checks.
+  // Therefore, we know that we can use a 64-bit induction variable that
+  // starts at ExtStart and counts up towards ExtEnd without overflowing.
+  ScalableVectorType *PredVTy =
+      ScalableVectorType::get(Builder.getInt1Ty(), MinNumBytes);
+
+  Value *InitialPred = Builder.CreateIntrinsic(
+      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
+
+  // Bytes processed per iteration: vscale * MinNumBytes.
+  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, MinNumBytes),
+                             "", /*HasNUW=*/true, /*HasNSW=*/true);
+
+  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
+                                            Builder.getInt1(false));
+
+  BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
+  Builder.Insert(JumpToVectorLoop);
+
+  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
+                     VectorLoopStartBlock}});
+
+  // Set up the first vector loop block by creating the PHIs, doing the vector
+  // loads and comparing the vectors.
+  Builder.SetInsertPoint(VectorLoopStartBlock);
+  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
+  LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
+  PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
+  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
+  Type *VectorLoadType = ScalableVectorType::get(LoadType, MinNumBytes);
+  // Inactive lanes of the masked loads read as zero; the select below masks
+  // them out of the comparison result anyway.
+  Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
+
+  Value *VectorLhsGep =
+      Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
+  Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
+                                                  Align(1), LoopPred, Passthru);
+
+  Value *VectorRhsGep =
+      Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
+  Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
+                                                  Align(1), LoopPred, Passthru);
+
+  Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
+  VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
+  Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
+  BranchInst *VectorEarlyExit = BranchInst::Create(
+      VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
+  Builder.Insert(VectorEarlyExit);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
+       {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
+
+  // Increment the index counter and calculate the predicate for the next
+  // iteration of the loop. We branch back to the start of the loop if there
+  // is at least one active lane.
+  Builder.SetInsertPoint(VectorLoopIncBlock);
+  Value *NewVectorIndexPhi =
+      Builder.CreateAdd(VectorIndexPhi, VecLen, "",
+                        /*HasNUW=*/true, /*HasNSW=*/true);
+  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
+  Value *NewPred =
+      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
+                              {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
+  LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
+
+  // Lane 0 of the new predicate is active iff any lane is, so testing it is
+  // enough to decide whether another iteration is needed.
+  Value *PredHasActiveLanes =
+      Builder.CreateExtractElement(NewPred, uint64_t(0));
+  BranchInst *VectorLoopBranchBack =
+      BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
+  Builder.Insert(VectorLoopBranchBack);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
+       {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
+
+  // If we found a mismatch then we need to calculate which lane in the vector
+  // had a mismatch and add that on to the current loop index. Single-entry
+  // PHIs carry the loop-start values into the mismatch block.
+  Builder.SetInsertPoint(VectorLoopMismatchBlock);
+  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
+  FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
+  PHINode *LastLoopPred =
+      Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
+  LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
+  PHINode *VectorFoundIndex =
+      Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
+  VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
+
+  // NOTE(review): VectorMatchCmp was already restricted to LoopPred by the
+  // select in the loop block, so this AND does not change the value; it is
+  // kept so the emitted IR stays identical.
+  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
+  // ZeroIsPoison is safe: this block is only reached when the or-reduction of
+  // VectorMatchCmp was true, so at least one lane of PredMatchCmp is set.
+  Value *Ctz = Builder.CreateIntrinsic(
+      Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
+      {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
+  Ctz = Builder.CreateZExt(Ctz, I64Type);
+  Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
+                                             /*HasNUW=*/true, /*HasNSW=*/true);
+  return Builder.CreateTrunc(VectorLoopRes64, ResType);
+}
+
Value *LoopIdiomVectorize::expandFindMismatch(
IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
@@ -345,8 +468,7 @@ Value *LoopIdiomVectorize::expandFindMismatch(
Type *ResType = Builder.getInt32Ty();
// Split block in the original loop preheader.
- BasicBlock *EndBlock =
- SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
+ EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
// Create the blocks that we're going to need:
// 1. A block for checking the zero-extended length exceeds 0
@@ -370,17 +492,17 @@ Value *LoopIdiomVectorize::expandFindMismatch(
BasicBlock *MemCheckBlock = BasicBlock::Create(
Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
- BasicBlock *VectorLoopPreheaderBlock = BasicBlock::Create(
+ VectorLoopPreheaderBlock = BasicBlock::Create(
Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);
- BasicBlock *VectorLoopStartBlock = BasicBlock::Create(
- Ctx, "mismatch_vec_loop", EndBlock->getParent(), EndBlock);
+ VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
+ EndBlock->getParent(), EndBlock);
- BasicBlock *VectorLoopIncBlock = BasicBlock::Create(
- Ctx, "mismatch_vec_loop_inc", EndBlock->getParent(), EndBlock);
+ VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
+ EndBlock->getParent(), EndBlock);
- BasicBlock *VectorLoopMismatchBlock = BasicBlock::Create(
- Ctx, "mismatch_vec_loop_found", EndBlock->getParent(), EndBlock);
+ VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
+ EndBlock->getParent(), EndBlock);
BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
@@ -491,104 +613,8 @@ Value *LoopIdiomVectorize::expandFindMismatch(
// processed in each iteration, etc.
Builder.SetInsertPoint(VectorLoopPreheaderBlock);
- // At this point we know two things must be true:
- // 1. Start <= End
- // 2. ExtMaxLen <= MinPageSize due to the page checks.
- // Therefore, we know that we can use a 64-bit induction variable that
- // starts from 0 -> ExtMaxLen and it will not overflow.
- ScalableVectorType *PredVTy =
- ScalableVectorType::get(Builder.getInt1Ty(), 16);
-
- Value *InitialPred = Builder.CreateIntrinsic(
- Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
-
- Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
- VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
- /*HasNUW=*/true, /*HasNSW=*/true);
-
- Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
- Builder.getInt1(false));
-
- BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
- Builder.Insert(JumpToVectorLoop);
-
- DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
- VectorLoopStartBlock}});
-
- // Set up the first vector loop block by creating the PHIs, doing the vector
- // loads and comparing the vectors.
- Builder.SetInsertPoint(VectorLoopStartBlock);
- PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
- LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
- PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
- VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
- Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
- Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
-
- Value *VectorLhsGep =
- Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
- Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
- Align(1), LoopPred, Passthru);
-
- Value *VectorRhsGep =
- Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
- Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
- Align(1), LoopPred, Passthru);
-
- Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
- VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
- Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
- BranchInst *VectorEarlyExit = BranchInst::Create(
- VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
- Builder.Insert(VectorEarlyExit);
-
- DTU.applyUpdates(
- {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
- {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
-
- // Increment the index counter and calculate the predicate for the next
- // iteration of the loop. We branch back to the start of the loop if there
- // is at least one active lane.
- Builder.SetInsertPoint(VectorLoopIncBlock);
- Value *NewVectorIndexPhi =
- Builder.CreateAdd(VectorIndexPhi, VecLen, "",
- /*HasNUW=*/true, /*HasNSW=*/true);
- VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
- Value *NewPred =
- Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
- {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
- LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
-
- Value *PredHasActiveLanes =
- Builder.CreateExtractElement(NewPred, uint64_t(0));
- BranchInst *VectorLoopBranchBack =
- BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
- Builder.Insert(VectorLoopBranchBack);
-
- DTU.applyUpdates(
- {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
- {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
-
- // If we found a mismatch then we need to calculate which lane in the vector
- // had a mismatch and add that on to the current loop index.
- Builder.SetInsertPoint(VectorLoopMismatchBlock);
- PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
- FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
- PHINode *LastLoopPred =
- Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
- LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
- PHINode *VectorFoundIndex =
- Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
- VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
-
- Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
- Value *Ctz = Builder.CreateIntrinsic(
- Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
- {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
- Ctz = Builder.CreateZExt(Ctz, I64Type);
- Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
- /*HasNUW=*/true, /*HasNSW=*/true);
- Value *VectorLoopRes = Builder.CreateTrunc(VectorLoopRes64, ResType);
+ Value *VectorLoopRes =
+ createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
Builder.Insert(BranchInst::Create(EndBlock));
More information about the llvm-commits
mailing list