[llvm] [RISCV][LoopIdiomVectorize] Support VP intrinsics in LoopIdiomVectorize (PR #94082)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 26 22:22:51 PDT 2024


================
@@ -331,6 +373,222 @@ bool LoopIdiomVectorize::recognizeByteCompare() {
   return true;
 }
 
+// Emit a vectorized mismatch-search loop over bytes using scalable vectors of
+// <vscale x ByteCompareVF x i8>, llvm.get_active_lane_mask predicates and
+// masked loads. Each iteration compares the bytes at PtrA[Index + lane] and
+// PtrB[Index + lane] for every lane active in the current predicate.
+//
+// Control flow created here (each edge is also registered with DTU):
+//   VectorLoopPreheaderBlock -> VectorLoopStartBlock
+//   VectorLoopStartBlock     -> VectorLoopMismatchBlock (an active lane differs)
+//                            -> VectorLoopIncBlock      (no active lane differs)
+//   VectorLoopIncBlock       -> VectorLoopStartBlock    (elements remain)
+//                            -> EndBlock                (no elements remain)
+//
+// \param Builder  The branch into the vector loop is inserted at its current
+//                 insert point. The DTU edge added below assumes that point is
+//                 in VectorLoopPreheaderBlock -- TODO(review) confirm callers.
+// \param GEPA, GEPB  Only their pointer operands and inbounds flags are used;
+//                 the loop indexes those base pointers with its own i64 index.
+// \param ExtStart, ExtEnd  i64 start index and end bound of the search. Per
+//                 get_active_lane_mask semantics, a lane is active while
+//                 Index + lane < ExtEnd (unsigned).
+// \returns The i32 index (truncated from i64) of the first mismatching byte.
+//          It is computed in VectorLoopMismatchBlock, so it is only meaningful
+//          on that path. Builder is left at the end of that block.
+Value *LoopIdiomVectorize::createMaskedFindMismatch(
+    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
+    GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
+  Type *I64Type = Builder.getInt64Ty();
+  Type *ResType = Builder.getInt32Ty();
+  Type *LoadType = Builder.getInt8Ty();
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Predicate type: one i1 per byte lane.
+  ScalableVectorType *PredVTy =
+      ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);
+
+  // First-iteration predicate: lanes whose index ExtStart + lane < ExtEnd.
+  Value *InitialPred = Builder.CreateIntrinsic(
+      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
+
+  // Elements processed per iteration: vscale * ByteCompareVF. The NUW/NSW
+  // flags assert that this product does not overflow i64.
+  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+  VecLen =
+      Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
+                        /*HasNUW=*/true, /*HasNSW=*/true);
+
+  // All-false predicate, used below to clear inactive comparison lanes.
+  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
+                                            Builder.getInt1(false));
+
+  BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
+  Builder.Insert(JumpToVectorLoop);
+
+  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
+                     VectorLoopStartBlock}});
+
+  // Set up the first vector loop block by creating the PHIs, doing the vector
+  // loads and comparing the vectors.
+  Builder.SetInsertPoint(VectorLoopStartBlock);
+  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
+  LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
+  PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
+  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
+  Type *VectorLoadType =
+      ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
+  // Value that both masked loads produce in inactive lanes.
+  Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
+
+  Value *VectorLhsGep =
+      Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
+  Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
+                                                  Align(1), LoopPred, Passthru);
+
+  Value *VectorRhsGep =
+      Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
+  Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
+                                                  Align(1), LoopPred, Passthru);
+
+  // Inactive lanes receive the same Passthru in both loads, so they already
+  // compare equal. The select makes their exclusion from the result explicit.
+  Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
+  VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
+  // Leave the loop for the mismatch block if any active lane differs.
+  Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
+  BranchInst *VectorEarlyExit = BranchInst::Create(
+      VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
+  Builder.Insert(VectorEarlyExit);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
+       {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
+
+  // Increment the index counter and calculate the predicate for the next
+  // iteration of the loop. We branch back to the start of the loop if there
+  // is at least one active lane.
+  Builder.SetInsertPoint(VectorLoopIncBlock);
+  Value *NewVectorIndexPhi =
+      Builder.CreateAdd(VectorIndexPhi, VecLen, "",
+                        /*HasNUW=*/true, /*HasNSW=*/true);
+  VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
+  Value *NewPred =
+      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
+                              {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
+  LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
+
+  // Lane 0 of the new mask is set iff NewVectorIndexPhi < ExtEnd, i.e. iff at
+  // least one element remains. Testing that lane alone decides whether to loop.
+  Value *PredHasActiveLanes =
+      Builder.CreateExtractElement(NewPred, uint64_t(0));
+  BranchInst *VectorLoopBranchBack =
+      BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
+  Builder.Insert(VectorLoopBranchBack);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
+       {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
+
+  // If we found a mismatch then we need to calculate which lane in the vector
+  // had a mismatch and add that on to the current loop index.
+  // In this function the only edge into VectorLoopMismatchBlock comes from
+  // VectorLoopStartBlock, so each PHI below has a single incoming value.
+  // They presumably exist to route loop-defined values out of the loop
+  // (LCSSA-style) -- NOTE(review): confirm.
+  Builder.SetInsertPoint(VectorLoopMismatchBlock);
+  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
+  FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
+  PHINode *LastLoopPred =
+      Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
+  LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
+  PHINode *VectorFoundIndex =
+      Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
+  VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
+
+  // FoundPred was already masked by LoopPred via the select above, and this
+  // block is entered only when its or-reduction is true. So PredMatchCmp has
+  // at least one set lane, which makes ZeroIsPoison=true safe for cttz.elts.
+  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
+  Value *Ctz = Builder.CreateIntrinsic(
+      Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
+      {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
+  // Result index = this iteration's base index + first mismatching lane.
+  Ctz = Builder.CreateZExt(Ctz, I64Type);
+  Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
+                                             /*HasNUW=*/true, /*HasNSW=*/true);
+  // Narrow the i64 index to the i32 result type.
+  return Builder.CreateTrunc(VectorLoopRes64, ResType);
+}
+
+Value *LoopIdiomVectorize::createPredicatedFindMismatch(
+    IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
+    GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
+  Type *I64Type = Builder.getInt64Ty();
+  Type *I32Type = Builder.getInt32Ty();
+  Type *ResType = I32Type;
+  Type *LoadType = Builder.getInt8Ty();
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
+  Builder.Insert(JumpToVectorLoop);
+
+  DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
+                     VectorLoopStartBlock}});
+
+  // Set up the first Vector loop block by creating the PHIs, doing the vector
+  // loads and comparing the vectors.
+  Builder.SetInsertPoint(VectorLoopStartBlock);
+  auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
+  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
+
+  // Calculate AVL by subtracting the vector loop index from the trip count
+  Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
+                                 /*HasNSW=*/true);
+
+  auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
+  auto *VF = ConstantInt::get(
+      I32Type, VectorLoadType->getElementCount().getKnownMinValue());
----------------
topperc wrote:

Can we just pass ByteCompareVF here? Why do we need to extract it from `VectorLoadType`?

https://github.com/llvm/llvm-project/pull/94082


More information about the llvm-commits mailing list