[llvm] [RISCV][LoopIdiomVectorize] Support VP intrinsics in LoopIdiomVectorize (PR #94082)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 26 22:22:50 PDT 2024
================
@@ -331,6 +373,222 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
return true;
}
+Value *LoopIdiomVectorize::createMaskedFindMismatch(
+ IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
+ GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) { // Emits the lane-mask-predicated vector mismatch loop; returns the i32 mismatch index computed in VectorLoopMismatchBlock.
+ Type *I64Type = Builder.getInt64Ty();
+ Type *ResType = Builder.getInt32Ty();
+ Type *LoadType = Builder.getInt8Ty();
+ Value *PtrA = GEPA->getPointerOperand();
+ Value *PtrB = GEPB->getPointerOperand();
+ // Scalable i1 vector type (minimum ByteCompareVF lanes) used for every predicate below.
+ ScalableVectorType *PredVTy =
+ ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);
+ // Predicate for the first iteration, from get_active_lane_mask(ExtStart, ExtEnd).
+ Value *InitialPred = Builder.CreateIntrinsic(
+ Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
+ // Per-iteration index step: vscale * ByteCompareVF, flagged nuw/nsw.
+ Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+ VecLen =
+ Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ // All-false predicate, used later to clear compare results in inactive lanes.
+ Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
+ Builder.getInt1(false));
+ // End the block at the builder's current insertion point with a jump to the loop header.
+ BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
+ Builder.Insert(JumpToVectorLoop);
+ // Record the new preheader -> loop header edge in the dominator tree.
+ DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
+ VectorLoopStartBlock}});
+
+ // Set up the first vector loop block by creating the PHIs, doing the vector
+ // loads and comparing the vectors.
+ Builder.SetInsertPoint(VectorLoopStartBlock);
+ PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
+ LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
+ PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
+ VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
+ Type *VectorLoadType =
+ ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
+ Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
+ // Masked byte-vector load from PtrA at the current index, with the zero Passthru as pass-through operand.
+ Value *VectorLhsGep =
+ Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
+ Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
+ Align(1), LoopPred, Passthru);
+ // Same masked load from PtrB.
+ Value *VectorRhsGep =
+ Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
+ Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
+ Align(1), LoopPred, Passthru);
+ // Lanes that differ; lanes outside LoopPred are forced to false before the OR-reduction.
+ Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
+ VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
+ Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
+ BranchInst *VectorEarlyExit = BranchInst::Create(
+ VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
+ Builder.Insert(VectorEarlyExit);
+ // Record both successor edges of the loop header in the dominator tree.
+ DTU.applyUpdates(
+ {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
+ {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
+
+ // Increment the index counter and calculate the predicate for the next
+ // iteration of the loop. We branch back to the start of the loop if there
+ // is at least one active lane.
+ Builder.SetInsertPoint(VectorLoopIncBlock);
+ Value *NewVectorIndexPhi =
+ Builder.CreateAdd(VectorIndexPhi, VecLen, "",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
+ Value *NewPred =
+ Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
+ {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
+ LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
+ // Only lane 0 of the next predicate is tested to choose between looping again and exiting to EndBlock.
+ Value *PredHasActiveLanes =
+ Builder.CreateExtractElement(NewPred, uint64_t(0));
+ BranchInst *VectorLoopBranchBack =
+ BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
+ Builder.Insert(VectorLoopBranchBack);
+ // Record the back edge and the exit edge in the dominator tree.
+ DTU.applyUpdates(
+ {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
+ {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
+
+ // If we found a mismatch then we need to calculate which lane in the vector
+ // had a mismatch and add that on to the current loop index.
+ Builder.SetInsertPoint(VectorLoopMismatchBlock);
+ PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
+ FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
+ PHINode *LastLoopPred =
+ Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
+ LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
+ PHINode *VectorFoundIndex =
+ Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
+ VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
+ // cttz.elts is called with ZeroIsPoison=true: this block is reached only when the OR-reduction saw a mismatch.
+ Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
+ Value *Ctz = Builder.CreateIntrinsic(
+ Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
+ {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
+ Ctz = Builder.CreateZExt(Ctz, I64Type);
+ Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ return Builder.CreateTrunc(VectorLoopRes64, ResType);
+}
+
+Value *LoopIdiomVectorize::createPredicatedFindMismatch(
+ IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
+ GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
+ Type *I64Type = Builder.getInt64Ty();
+ Type *I32Type = Builder.getInt32Ty();
+ Type *ResType = I32Type;
+ Type *LoadType = Builder.getInt8Ty();
+ Value *PtrA = GEPA->getPointerOperand();
+ Value *PtrB = GEPB->getPointerOperand();
+
+ auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
+ Builder.Insert(JumpToVectorLoop);
+
+ DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
+ VectorLoopStartBlock}});
+
+ // Set up the first Vector loop block by creating the PHIs, doing the vector
+ // loads and comparing the vectors.
+ Builder.SetInsertPoint(VectorLoopStartBlock);
+ auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
+ VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
+
+ // Calculate AVL by subtracting the vector loop index from the trip count
+ Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
+ /*HasNSW=*/true);
+
+ auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
+ auto *VF = ConstantInt::get(
+ I32Type, VectorLoadType->getElementCount().getKnownMinValue());
+ auto *IsScalable = ConstantInt::getBool(
+ Builder.getContext(), VectorLoadType->getElementCount().isScalable());
----------------
topperc wrote:
Can we just pass `true` here instead of extracting it from the type?
https://github.com/llvm/llvm-project/pull/94082
More information about the llvm-commits
mailing list