[llvm] [LoopIdiomVectorize][NFC] Factoring out the part that handles vectorization strategy (PR #94682)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 6 13:47:45 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Min-Yih Hsu (mshockwave)

<details>
<summary>Changes</summary>

This change paves the way for porting LoopIdiomVectorize (LIV) to RISC-V, which uses VP intrinsics for vectors.

NFC.

-----
This PR stacks on top of #<!-- -->94681

---

Patch is 100.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94682.diff


11 Files Affected:

- (renamed) llvm/include/llvm/Transforms/Vectorize/LoopIdiomVectorize.h (+5-9) 
- (modified) llvm/lib/Passes/PassBuilder.cpp (+1) 
- (modified) llvm/lib/Passes/PassRegistry.def (+1) 
- (modified) llvm/lib/Target/AArch64/AArch64.h (-1) 
- (removed) llvm/lib/Target/AArch64/AArch64PassRegistry.def (-20) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetMachine.cpp (+2-6) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetMachine.h (-1) 
- (modified) llvm/lib/Target/AArch64/CMakeLists.txt (+1-1) 
- (modified) llvm/lib/Transforms/Vectorize/CMakeLists.txt (+1) 
- (renamed) llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp (+203-250) 
- (modified) llvm/test/Transforms/LoopIdiom/AArch64/byte-compare-index.ll (+201-204) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.h b/llvm/include/llvm/Transforms/Vectorize/LoopIdiomVectorize.h
similarity index 60%
rename from llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.h
rename to llvm/include/llvm/Transforms/Vectorize/LoopIdiomVectorize.h
index cc68425bb68b5..56f44b7dc6b2a 100644
--- a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopIdiomVectorize.h
@@ -1,4 +1,4 @@
-//===- AArch64LoopIdiomTransform.h --------------------------------------===//
+//===----------LoopIdiomVectorize.h -----------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,20 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_LIB_TARGET_AARCH64_AARCH64LOOPIDIOMTRANSFORM_H
-#define LLVM_LIB_TARGET_AARCH64_AARCH64LOOPIDIOMTRANSFORM_H
+#ifndef LLVM_LIB_TRANSFORMS_VECTORIZE_LOOPIDIOMVECTORIZE_H
+#define LLVM_LIB_TRANSFORMS_VECTORIZE_LOOPIDIOMVECTORIZE_H
 
 #include "llvm/IR/PassManager.h"
 #include "llvm/Transforms/Scalar/LoopPassManager.h"
 
 namespace llvm {
-
-struct AArch64LoopIdiomTransformPass
-    : PassInfoMixin<AArch64LoopIdiomTransformPass> {
+struct LoopIdiomVectorizePass : PassInfoMixin<LoopIdiomVectorizePass> {
   PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
                         LoopStandardAnalysisResults &AR, LPMUpdater &U);
 };
-
 } // namespace llvm
-
-#endif // LLVM_LIB_TARGET_AARCH64_AARCH64LOOPIDIOMTRANSFORM_H
+#endif // LLVM_LIB_TRANSFORMS_VECTORIZE_LOOPIDIOMVECTORIZE_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 316d05bf1dc37..9d8c3c4d7bdee 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -297,6 +297,7 @@
 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
 #include "llvm/Transforms/Utils/UnifyLoopExits.h"
 #include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
+#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
 #include "llvm/Transforms/Vectorize/LoopVectorize.h"
 #include "llvm/Transforms/Vectorize/SLPVectorizer.h"
 #include "llvm/Transforms/Vectorize/VectorCombine.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 50682ca4970f1..f71745a77a19b 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -621,6 +621,7 @@ LOOP_PASS("invalidate<all>", InvalidateAllAnalysesPass())
 LOOP_PASS("loop-bound-split", LoopBoundSplitPass())
 LOOP_PASS("loop-deletion", LoopDeletionPass())
 LOOP_PASS("loop-idiom", LoopIdiomRecognizePass())
+LOOP_PASS("loop-idiom-vectorize", LoopIdiomVectorizePass())
 LOOP_PASS("loop-instsimplify", LoopInstSimplifyPass())
 LOOP_PASS("loop-predication", LoopPredicationPass())
 LOOP_PASS("loop-reduce", LoopStrengthReducePass())
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index 0f0a22ec82936..6f2aeb83a451a 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -90,7 +90,6 @@ void initializeAArch64DeadRegisterDefinitionsPass(PassRegistry&);
 void initializeAArch64ExpandPseudoPass(PassRegistry &);
 void initializeAArch64GlobalsTaggingPass(PassRegistry &);
 void initializeAArch64LoadStoreOptPass(PassRegistry&);
-void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
 void initializeAArch64LowerHomogeneousPrologEpilogPass(PassRegistry &);
 void initializeAArch64MIPeepholeOptPass(PassRegistry &);
 void initializeAArch64O0PreLegalizerCombinerPass(PassRegistry &);
diff --git a/llvm/lib/Target/AArch64/AArch64PassRegistry.def b/llvm/lib/Target/AArch64/AArch64PassRegistry.def
deleted file mode 100644
index ca944579f93a9..0000000000000
--- a/llvm/lib/Target/AArch64/AArch64PassRegistry.def
+++ /dev/null
@@ -1,20 +0,0 @@
-//===- AArch64PassRegistry.def - Registry of AArch64 passes -----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file is used as the registry of passes that are part of the
-// AArch64 backend.
-//
-//===----------------------------------------------------------------------===//
-
-// NOTE: NO INCLUDE GUARD DESIRED!
-
-#ifndef LOOP_PASS
-#define LOOP_PASS(NAME, CREATE_PASS)
-#endif
-LOOP_PASS("aarch64-lit", AArch64LoopIdiomTransformPass())
-#undef LOOP_PASS
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 30f0ceaf674c6..7de9071476e7f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -11,7 +11,6 @@
 
 #include "AArch64TargetMachine.h"
 #include "AArch64.h"
-#include "AArch64LoopIdiomTransform.h"
 #include "AArch64MachineFunctionInfo.h"
 #include "AArch64MachineScheduler.h"
 #include "AArch64MacroFusion.h"
@@ -52,6 +51,7 @@
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/CFGuard.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
 #include <memory>
 #include <optional>
 #include <string>
@@ -234,7 +234,6 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeAArch64Target() {
   initializeAArch64DeadRegisterDefinitionsPass(*PR);
   initializeAArch64ExpandPseudoPass(*PR);
   initializeAArch64LoadStoreOptPass(*PR);
-  initializeAArch64LoopIdiomTransformLegacyPassPass(*PR);
   initializeAArch64MIPeepholeOptPass(*PR);
   initializeAArch64SIMDInstrOptPass(*PR);
   initializeAArch64O0PreLegalizerCombinerPass(*PR);
@@ -553,12 +552,9 @@ class AArch64PassConfig : public TargetPassConfig {
 void AArch64TargetMachine::registerPassBuilderCallbacks(
     PassBuilder &PB, bool PopulateClassToPassNames) {
 
-#define GET_PASS_REGISTRY "AArch64PassRegistry.def"
-#include "llvm/Passes/TargetPassRegistry.inc"
-
   PB.registerLateLoopOptimizationsEPCallback(
       [=](LoopPassManager &LPM, OptimizationLevel Level) {
-        LPM.addPass(AArch64LoopIdiomTransformPass());
+        LPM.addPass(LoopIdiomVectorizePass());
       });
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.h b/llvm/lib/Target/AArch64/AArch64TargetMachine.h
index 8fb68b06f1378..e396d9204716a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.h
@@ -14,7 +14,6 @@
 #define LLVM_LIB_TARGET_AARCH64_AARCH64TARGETMACHINE_H
 
 #include "AArch64InstrInfo.h"
-#include "AArch64LoopIdiomTransform.h"
 #include "AArch64Subtarget.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Target/TargetMachine.h"
diff --git a/llvm/lib/Target/AArch64/CMakeLists.txt b/llvm/lib/Target/AArch64/CMakeLists.txt
index 8e76f6c9279e7..639bc0707dff2 100644
--- a/llvm/lib/Target/AArch64/CMakeLists.txt
+++ b/llvm/lib/Target/AArch64/CMakeLists.txt
@@ -65,7 +65,6 @@ add_llvm_target(AArch64CodeGen
   AArch64ISelLowering.cpp
   AArch64InstrInfo.cpp
   AArch64LoadStoreOptimizer.cpp
-  AArch64LoopIdiomTransform.cpp
   AArch64LowerHomogeneousPrologEpilog.cpp
   AArch64MachineFunctionInfo.cpp
   AArch64MachineScheduler.cpp
@@ -112,6 +111,7 @@ add_llvm_target(AArch64CodeGen
   Target
   TargetParser
   TransformUtils
+  Vectorize
 
   ADD_TO_COMPONENT
   AArch64
diff --git a/llvm/lib/Transforms/Vectorize/CMakeLists.txt b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
index 9674094024b9e..4caec07c5ac43 100644
--- a/llvm/lib/Transforms/Vectorize/CMakeLists.txt
+++ b/llvm/lib/Transforms/Vectorize/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_component_library(LLVMVectorize
   LoadStoreVectorizer.cpp
+  LoopIdiomVectorize.cpp
   LoopVectorizationLegality.cpp
   LoopVectorize.cpp
   SLPVectorizer.cpp
diff --git a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
similarity index 71%
rename from llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
rename to llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
index 8ae3f014d45e0..6b6b067db30b7 100644
--- a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp
@@ -1,4 +1,4 @@
-//===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
+//===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -26,6 +26,10 @@
 //
 //===----------------------------------------------------------------------===//
 //
+// NOTE: This Pass matches a really specific loop pattern because it's only
+// supposed to be a temporary solution until our LoopVectorizer is powerful
+// enough to vectorize it automatically.
+//
 // TODO List:
 //
 // * Add support for the inverse case where we scan for a matching element.
@@ -35,7 +39,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "AArch64LoopIdiomTransform.h"
+#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -44,47 +49,47 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
 using namespace llvm;
 using namespace PatternMatch;
 
-#define DEBUG_TYPE "aarch64-loop-idiom-transform"
-
-static cl::opt<bool>
-    DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false),
-               cl::desc("Disable AArch64 Loop Idiom Transform Pass."));
-
-static cl::opt<bool> DisableByteCmp(
-    "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false),
-    cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do "
-             "not convert byte-compare loop(s)."));
+#define DEBUG_TYPE "loop-idiom-vectorize"
 
-static cl::opt<bool> VerifyLoops(
-    "aarch64-lit-verify", cl::Hidden, cl::init(false),
-    cl::desc("Verify loops generated AArch64 Loop Idiom Transform Pass."));
+static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
+                                cl::init(false),
+                                cl::desc("Disable Loop Idiom Vectorize Pass."));
 
-namespace llvm {
-
-void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
-Pass *createAArch64LoopIdiomTransformPass();
+static cl::opt<bool>
+    DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
+                   cl::init(false),
+                   cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
+                            "not convert byte-compare loop(s)."));
 
-} // end namespace llvm
+static cl::opt<bool>
+    VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
+                cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
 
 namespace {
 
-class AArch64LoopIdiomTransform {
+class LoopIdiomVectorize {
   Loop *CurLoop = nullptr;
   DominatorTree *DT;
   LoopInfo *LI;
   const TargetTransformInfo *TTI;
   const DataLayout *DL;
 
+  // Blocks that will be used for inserting vectorized code.
+  BasicBlock *EndBlock = nullptr;
+  BasicBlock *VectorLoopPreheaderBlock = nullptr;
+  BasicBlock *VectorLoopStartBlock = nullptr;
+  BasicBlock *VectorLoopMismatchBlock = nullptr;
+  BasicBlock *VectorLoopIncBlock = nullptr;
+
 public:
-  explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
-                                     const TargetTransformInfo *TTI,
-                                     const DataLayout *DL)
+  explicit LoopIdiomVectorize(DominatorTree *DT, LoopInfo *LI,
+                              const TargetTransformInfo *TTI,
+                              const DataLayout *DL)
       : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
 
   bool run(Loop *L);
@@ -98,83 +103,32 @@ class AArch64LoopIdiomTransform {
                       SmallVectorImpl<BasicBlock *> &ExitBlocks);
 
   bool recognizeByteCompare();
+
   Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
                             GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                             Instruction *Index, Value *Start, Value *MaxLen);
+
+  Value *createMaskedFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
+                                  GetElementPtrInst *GEPB, Value *ExtStart,
+                                  Value *ExtEnd);
+
   void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
                             PHINode *IndPhi, Value *MaxLen, Instruction *Index,
                             Value *Start, bool IncIdx, BasicBlock *FoundBB,
                             BasicBlock *EndBB);
   /// @}
 };
+} // anonymous namespace
 
-class AArch64LoopIdiomTransformLegacyPass : public LoopPass {
-public:
-  static char ID;
-
-  explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) {
-    initializeAArch64LoopIdiomTransformLegacyPassPass(
-        *PassRegistry::getPassRegistry());
-  }
-
-  StringRef getPassName() const override {
-    return "Transform AArch64-specific loop idioms";
-  }
-
-  void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addRequired<LoopInfoWrapperPass>();
-    AU.addRequired<DominatorTreeWrapperPass>();
-    AU.addRequired<TargetTransformInfoWrapperPass>();
-  }
-
-  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
-};
-
-bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L,
-                                                    LPPassManager &LPM) {
-
-  if (skipLoop(L))
-    return false;
-
-  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
-  auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
-  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
-      *L->getHeader()->getParent());
-  return AArch64LoopIdiomTransform(
-             DT, LI, &TTI, &L->getHeader()->getModule()->getDataLayout())
-      .run(L);
-}
-
-} // end anonymous namespace
-
-char AArch64LoopIdiomTransformLegacyPass::ID = 0;
-
-INITIALIZE_PASS_BEGIN(
-    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
-    "Transform specific loop idioms into optimized vector forms", false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
-INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
-INITIALIZE_PASS_END(
-    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
-    "Transform specific loop idioms into optimized vector forms", false, false)
-
-Pass *llvm::createAArch64LoopIdiomTransformPass() {
-  return new AArch64LoopIdiomTransformLegacyPass();
-}
-
-PreservedAnalyses
-AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
-                                   LoopStandardAnalysisResults &AR,
-                                   LPMUpdater &) {
+PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
+                                              LoopStandardAnalysisResults &AR,
+                                              LPMUpdater &) {
   if (DisableAll)
     return PreservedAnalyses::all();
 
   const auto *DL = &L.getHeader()->getModule()->getDataLayout();
 
-  AArch64LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
+  LoopIdiomVectorize LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
   if (!LIT.run(&L))
     return PreservedAnalyses::all();
 
@@ -183,11 +137,11 @@ AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
 
 //===----------------------------------------------------------------------===//
 //
-//          Implementation of AArch64LoopIdiomTransform
+//          Implementation of LoopIdiomVectorize
 //
 //===----------------------------------------------------------------------===//
 
-bool AArch64LoopIdiomTransform::run(Loop *L) {
+bool LoopIdiomVectorize::run(Loop *L) {
   CurLoop = L;
 
   Function &F = *L->getHeader()->getParent();
@@ -211,7 +165,7 @@ bool AArch64LoopIdiomTransform::run(Loop *L) {
   return recognizeByteCompare();
 }
 
-bool AArch64LoopIdiomTransform::recognizeByteCompare() {
+bool LoopIdiomVectorize::recognizeByteCompare() {
   // Currently the transformation only works on scalable vector types, although
   // there is no fundamental reason why it cannot be made to work for fixed
   // width too.
@@ -224,7 +178,7 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
 
   BasicBlock *Header = CurLoop->getHeader();
 
-  // In AArch64LoopIdiomTransform::run we have already checked that the loop
+  // In LoopIdiomVectorize::run we have already checked that the loop
   // has a preheader so we can assume it's in a canonical form.
   if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
     return false;
@@ -242,8 +196,7 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not = icmp eq i32 %inc, %n
   //   br i1 %cmp.not, label %while.end, label %while.body
   //
-  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
-  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
+  if (LoopBlocks[0]->sizeWithoutDebug() > 4)
     return false;
 
   // The second block should contain 7 instructions, e.g.
@@ -257,8 +210,7 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
   //   br i1 %cmp.not.ld, label %while.cond, label %while.end
   //
-  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
-  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+  if (LoopBlocks[1]->sizeWithoutDebug() > 7)
     return false;
 
   // The incoming value to the PHI node from the loop should be an add of 1.
@@ -393,7 +345,107 @@ bool AArch64LoopIdiomTransform::recognizeByteCompare() {
   return true;
 }
 
-Value *AArch64LoopIdiomTransform::expandFindMismatch(
+Value *LoopIdiomVectorize::createMaskedFindMismatch(IRBuilder<> &Builder,
+                                                    GetElementPtrInst *GEPA,
+                                                    GetElementPtrInst *GEPB,
+                                                    Value *ExtStart,
+                                                    Value *ExtEnd) {
+  Type *I64Type = Builder.getInt64Ty();
+  Type *ResType = Builder.getInt32Ty();
+  Type *LoadType = Builder.getInt8Ty();
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // At this point we know two things must be true:
+  //  1. Start <= End
+  //  2. ExtMaxLen <= MinPageSize due to the page checks.
+  // Therefore, we know that we can use a 64-bit induction variable that
+  // starts from 0 -> ExtMaxLen and it will not overflow.
+  ScalableVectorType *PredVTy =
+      ScalableVectorType::get(Builder.getInt1Ty(), 16);
+
+  Value *InitialPred = Builder.CreateIntrinsic(
+      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
+
+  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
+                             /*HasNUW=*/true, /*HasNSW=*/true);
+
+  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
+                                            Builder.getInt1(false));
+
+  BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
+  Builder.Insert(JumpToVectorLoop);
+
+  // Set up the first vector loop block by creating the PHIs, doing the vector
+  // loads and comparing the vectors.
+  Builder.SetInsertPoint(VectorLoopStartBlock);
+  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
+  LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
+  PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
+  VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
+  Type *VectorLoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
+  Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
+
+  Value *VectorLhsGep =
+      Builder.CreateGE...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/94682


More information about the llvm-commits mailing list