[llvm] [AArch64] Add an AArch64 pass for loop idiom transformations (PR #72273)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 14 07:49:20 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-llvm-transforms

Author: David Sherwood (david-arm)

<details>
<summary>Changes</summary>

We have added a new pass that looks for loops such as the following:

  while (++i != max_len)
      if (a[i] != b[i])
          break;

  ... use index i ...

Although this resembles a memcmp, it differs in what it produces: memcmp returns the difference between the values of the first non-matching pair of bytes, whereas this loop yields the index of the first mismatch. As such, we are not able to lower this to a memcmp call.
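To make the distinction concrete, here is a minimal scalar sketch of the idiom. The function name `find_mismatch_index` and its parameters are hypothetical and do not appear in the patch. The pre-increment form follows the IR pattern the pass matches, where the index is incremented before the loads.

```cpp
#include <cstdint>

// Hypothetical helper, for illustration only: a scalar model of the loop
// idiom. The index is incremented before each comparison, so the element at
// `start` itself is never compared.
std::uint32_t find_mismatch_index(const std::uint8_t *a, const std::uint8_t *b,
                                  std::uint32_t start, std::uint32_t max_len) {
  std::uint32_t i = start;
  while (++i != max_len)
    if (a[i] != b[i])
      break;
  // Either the index of the first mismatch after `start`, or max_len if
  // every compared byte matched.
  return i;
}
```

For the byte strings "abcdef" and "abcxef" with start 0 and max_len 6, this returns 3, whereas memcmp on the same inputs only reports a negative value.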

The new pass can now spot such idioms and transform them into a specialised predicated loop that gives a significant performance improvement for AArch64. It is intended as a stop-gap solution until this can be handled by the vectoriser, which doesn't currently deal with early exits.

This specialised loop makes use of a generic intrinsic that counts the trailing zero elements in a predicate vector. The intrinsic was added in https://reviews.llvm.org/D159283, and for SVE we end up with brkb and incp instructions.
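As a rough scalar model of what such an intrinsic computes, the sketch below counts zero elements starting from element 0 of a predicate. The name `cttz_elts_model` is hypothetical. This only illustrates the semantics and says nothing about the generic or SVE lowering, and it assumes that an all-zero predicate yields the number of elements.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical scalar model, for illustration only: count the zero (false)
// elements before the first set element, scanning from element 0 upwards.
// Assumption: an all-false predicate yields the element count.
std::size_t cttz_elts_model(const std::vector<bool> &pred) {
  std::size_t n = 0;
  while (n < pred.size() && !pred[n])
    ++n;
  return n;
}
```

If the predicate marks the lanes where a[i] != b[i], the result is the lane offset of the first mismatch within the current vector, which is what the loop idiom needs.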

Although we have added this pass only for AArch64, it was written in a generic way so that in theory it could be used by other targets. Currently the pass requires scalable vector support and needs to know the minimum page size for the target; however, it's possible to make it work for fixed-width vectors too. Also, the llvm.experimental.cttz.elts intrinsic used by the pass has generic lowering, but can be made efficient for targets with instructions similar to SVE's brkb, cntp and incp.
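The minimum page size is needed because the generated code checks that the start and end addresses of each array lie on the same page before entering the vector loop. Presumably this is so that a full-width vector load cannot fault where the scalar loop would already have stopped. A minimal sketch of such a check follows. The function name `on_same_page` and its parameters are hypothetical, and the actual IR emitted by the pass may compute this differently.

```cpp
#include <cstdint>

// Hypothetical helper, for illustration only: returns true when the first and
// last byte addresses of an access range fall within the same page.
// Assumption: min_page_size is non-zero.
bool on_same_page(std::uint64_t first_addr, std::uint64_t last_addr,
                  std::uint64_t min_page_size) {
  // Two addresses share a page exactly when their page numbers are equal.
  return (first_addr / min_page_size) == (last_addr / min_page_size);
}
```

With a 4096-byte minimum page, addresses 0x1000 and 0x1FFF share a page, while 0x1FFF and 0x2000 do not. In the latter case the scalar loop would be the safe fallback.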

The original version of the patch was posted on Phabricator:

 https://reviews.llvm.org/D158291

Patch co-authored by Kerry McLaughlin (@kmclaughlin-arm) and David Sherwood (@david-arm)

See the original discussion on Discourse:
https://discourse.llvm.org/t/aarch64-target-specific-loop-idiom-recognition/72383

---

Patch is 138.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72273.diff


12 Files Affected:

- (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+8) 
- (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+2) 
- (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4) 
- (modified) llvm/lib/Target/AArch64/AArch64.h (+1) 
- (added) llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp (+726) 
- (added) llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.h (+25) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetMachine.cpp (+10) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetMachine.h (+3) 
- (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2) 
- (modified) llvm/lib/Target/AArch64/CMakeLists.txt (+1) 
- (added) llvm/test/Transforms/LoopIdiom/AArch64/byte-compare-index.ll (+1640) 
- (modified) llvm/utils/gn/secondary/llvm/lib/Target/AArch64/BUILD.gn (+1) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index c18e0acdb759a8d..10d178a73b0fcdd 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1155,6 +1155,9 @@ class TargetTransformInfo {
   /// \return The associativity of the cache level, if available.
   std::optional<unsigned> getCacheAssociativity(CacheLevel Level) const;
 
+  /// \return The minimum architectural page size for the target.
+  std::optional<unsigned> getMinPageSize() const;
+
   /// \return How much before a load we should place the prefetch
   /// instruction.  This is currently measured in number of
   /// instructions.
@@ -1889,6 +1892,7 @@ class TargetTransformInfo::Concept {
   virtual std::optional<unsigned> getCacheSize(CacheLevel Level) const = 0;
   virtual std::optional<unsigned> getCacheAssociativity(CacheLevel Level)
       const = 0;
+  virtual std::optional<unsigned> getMinPageSize() const = 0;
 
   /// \return How much before a load we should place the prefetch
   /// instruction.  This is currently measured in number of
@@ -2475,6 +2479,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
     return Impl.getCacheAssociativity(Level);
   }
 
+  std::optional<unsigned> getMinPageSize() const override {
+    return Impl.getMinPageSize();
+  }
+
   /// Return the preferred prefetch distance in terms of instructions.
   ///
   unsigned getPrefetchDistance() const override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 2ccf57c22234f9a..13030cb9fe46825 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -494,6 +494,8 @@ class TargetTransformInfoImplBase {
     llvm_unreachable("Unknown TargetTransformInfo::CacheLevel");
   }
 
+  std::optional<unsigned> getMinPageSize() const { return {}; }
+
   unsigned getPrefetchDistance() const { return 0; }
   unsigned getMinPrefetchStride(unsigned NumMemAccesses,
                                 unsigned NumStridedMemAccesses,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index caa9b17ae695e49..dfe8d004bb97899 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -753,6 +753,10 @@ TargetTransformInfo::getCacheAssociativity(CacheLevel Level) const {
   return TTIImpl->getCacheAssociativity(Level);
 }
 
+std::optional<unsigned> TargetTransformInfo::getMinPageSize() const {
+  return TTIImpl->getMinPageSize();
+}
+
 unsigned TargetTransformInfo::getPrefetchDistance() const {
   return TTIImpl->getPrefetchDistance();
 }
diff --git a/llvm/lib/Target/AArch64/AArch64.h b/llvm/lib/Target/AArch64/AArch64.h
index 901769c54b6ef59..d20ef63a72e8f62 100644
--- a/llvm/lib/Target/AArch64/AArch64.h
+++ b/llvm/lib/Target/AArch64/AArch64.h
@@ -88,6 +88,7 @@ void initializeAArch64DeadRegisterDefinitionsPass(PassRegistry&);
 void initializeAArch64ExpandPseudoPass(PassRegistry &);
 void initializeAArch64GlobalsTaggingPass(PassRegistry &);
 void initializeAArch64LoadStoreOptPass(PassRegistry&);
+void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
 void initializeAArch64LowerHomogeneousPrologEpilogPass(PassRegistry &);
 void initializeAArch64MIPeepholeOptPass(PassRegistry &);
 void initializeAArch64O0PreLegalizerCombinerPass(PassRegistry &);
diff --git a/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp b/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
new file mode 100644
index 000000000000000..c4d2529a3e53e7d
--- /dev/null
+++ b/llvm/lib/Target/AArch64/AArch64LoopIdiomTransform.cpp
@@ -0,0 +1,726 @@
+
+//===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "AArch64LoopIdiomTransform.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-lit"
+
+static cl::opt<bool>
+    DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false),
+               cl::desc("Disable AArch64 Loop Idiom Transform Pass."));
+
+static cl::opt<bool> DisableByteCmp(
+    "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false),
+    cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do "
+             "not convert byte-compare loop(s)."));
+
+namespace llvm {
+
+void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
+Pass *createAArch64LoopIdiomTransformPass();
+
+} // end namespace llvm
+
+namespace {
+
+class AArch64LoopIdiomTransform {
+  Loop *CurLoop = nullptr;
+  DominatorTree *DT;
+  LoopInfo *LI;
+  const TargetTransformInfo *TTI;
+  const DataLayout *DL;
+
+public:
+  explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
+                                     const TargetTransformInfo *TTI,
+                                     const DataLayout *DL)
+      : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
+
+  bool run(Loop *L);
+
+private:
+  /// \name Countable Loop Idiom Handling
+  /// @{
+
+  bool runOnCountableLoop();
+  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
+                      SmallVectorImpl<BasicBlock *> &ExitBlocks);
+
+  bool recognizeByteCompare();
+  Value *expandFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
+                            GetElementPtrInst *GEPB, Value *Start,
+                            Value *MaxLen);
+  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+                            Value *MaxLen, Value *Index, Value *Start,
+                            bool IncIdx, BasicBlock *FoundBB,
+                            BasicBlock *EndBB);
+  /// @}
+};
+
+class AArch64LoopIdiomTransformLegacyPass : public LoopPass {
+public:
+  static char ID;
+
+  explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) {
+    initializeAArch64LoopIdiomTransformLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  StringRef getPassName() const override {
+    return "Recognize AArch64-specific loop idioms";
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+  }
+
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+};
+
+bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L,
+                                                    LPPassManager &LPM) {
+
+  if (skipLoop(L))
+    return false;
+
+  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+      *L->getHeader()->getParent());
+  return AArch64LoopIdiomTransform(
+             DT, LI, &TTI, &L->getHeader()->getModule()->getDataLayout())
+      .run(L);
+}
+
+} // end anonymous namespace
+
+char AArch64LoopIdiomTransformLegacyPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(
+    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
+    "Transform specific loop idioms into optimised vector forms", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(
+    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
+    "Transform specific loop idioms into optimised vector forms", false, false)
+
+Pass *llvm::createAArch64LoopIdiomTransformPass() {
+  return new AArch64LoopIdiomTransformLegacyPass();
+}
+
+PreservedAnalyses
+AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
+                                   LoopStandardAnalysisResults &AR,
+                                   LPMUpdater &) {
+  if (DisableAll)
+    return PreservedAnalyses::all();
+
+  const auto *DL = &L.getHeader()->getModule()->getDataLayout();
+
+  AArch64LoopIdiomTransform LIT(&AR.DT, &AR.LI, &AR.TTI, DL);
+  if (!LIT.run(&L))
+    return PreservedAnalyses::all();
+
+  return PreservedAnalyses::none();
+}
+
+//===----------------------------------------------------------------------===//
+//
+//          Implementation of AArch64LoopIdiomTransform
+//
+//===----------------------------------------------------------------------===//
+
+bool AArch64LoopIdiomTransform::run(Loop *L) {
+  CurLoop = L;
+
+  if (DisableAll)
+    return false;
+
+  // If the loop could not be converted to canonical form, it must have an
+  // indirectbr in it, just give up.
+  if (!L->getLoopPreheader())
+    return false;
+
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
+                    << CurLoop->getHeader()->getParent()->getName()
+                    << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
+
+  return recognizeByteCompare();
+}
+
+/// Match loop-invariant value.
+template <typename SubPattern_t> struct match_LoopInvariant {
+  SubPattern_t SubPattern;
+  const Loop *L;
+
+  match_LoopInvariant(const SubPattern_t &SP, const Loop *L)
+      : SubPattern(SP), L(L) {}
+
+  template <typename ITy> bool match(ITy *V) {
+    return L->isLoopInvariant(V) && SubPattern.match(V);
+  }
+};
+
+/// Matches if the value is loop-invariant.
+template <typename Ty>
+inline match_LoopInvariant<Ty> m_LoopInvariant(const Ty &M, const Loop *L) {
+  return match_LoopInvariant<Ty>(M, L);
+}
+
+bool AArch64LoopIdiomTransform::recognizeByteCompare() {
+  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+      DisableByteCmp)
+    return false;
+
+  BasicBlock *Header = CurLoop->getHeader();
+  BasicBlock *PH = CurLoop->getLoopPreheader();
+
+  // In AArch64LoopIdiomTransform::run we have already checked that the loop
+  // has a preheader so we can assume it's in a canonical form.
+  auto *EntryBI = cast<BranchInst>(PH->getTerminator());
+
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+    return false;
+
+  PHINode *PN = dyn_cast<PHINode>(&Header->front());
+  if (!PN || PN->getNumIncomingValues() != 2)
+    return false;
+
+  auto LoopBlocks = CurLoop->getBlocks();
+  // The first block in the loop should contain only 4 instructions, e.g.
+  //
+  //  while.cond:
+  //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
+  //   %inc = add i32 %res.phi, 1
+  //   %cmp.not = icmp eq i32 %inc, %n
+  //   br i1 %cmp.not, label %while.end, label %while.body
+  //
+  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
+  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
+    return false;
+
+  // The second block should contain 7 instructions, e.g.
+  //
+  // while.body:
+  //   %idx = zext i32 %inc to i64
+  //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
+  //   %load.a = load i8, ptr %idx.a
+  //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
+  //   %load.b = load i8, ptr %idx.b
+  //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
+  //   br i1 %cmp.not.ld, label %while.cond, label %while.end
+  //
+  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
+  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+    return false;
+
+  using namespace PatternMatch;
+
+  // The incoming value to the PHI node from the loop should be an add of 1.
+  Instruction *Index = nullptr;
+  Value *StartIdx = nullptr;
+  for (BasicBlock *BB : PN->blocks()) {
+    if (!CurLoop->contains(BB)) {
+      StartIdx = PN->getIncomingValueForBlock(BB);
+      continue;
+    }
+    Index = dyn_cast<Instruction>(PN->getIncomingValueForBlock(BB));
+    // Limit to 32-bit types for now
+    if (!Index || !Index->getType()->isIntegerTy(32) ||
+        !match(Index, m_c_Add(m_Specific(PN), m_One())))
+      return false;
+  }
+
+  // If we match the pattern, PN and Index will be replaced with the result of
+  // the cttz.elts intrinsic. If any other instructions are used outside of
+  // the loop, we cannot replace it.
+  for (BasicBlock *BB : LoopBlocks)
+    for (Instruction &I : *BB)
+      if (&I != PN && &I != Index)
+        for (User *U : I.users()) {
+          auto UI = cast<Instruction>(U);
+          if (!CurLoop->contains(UI))
+            return false;
+        }
+
+  // Don't replace the loop if the add has a wrap flag.
+  if (Index->hasNoSignedWrap() || Index->hasNoUnsignedWrap())
+    return false;
+
+  // Match the branch instruction for the header
+  ICmpInst::Predicate Pred;
+  Value *MaxLen;
+  BasicBlock *EndBB, *WhileBB;
+  if (!match(Header->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
+                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
+      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
+    return false;
+
+  // WhileBB should contain the pattern of load & compare instructions. Match
+  // the pattern and find the GEP instructions used by the loads.
+  ICmpInst::Predicate WhilePred;
+  BasicBlock *FoundBB;
+  BasicBlock *TrueBB;
+  Value *LoadA, *LoadB;
+  if (!match(WhileBB->getTerminator(),
+             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
+                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
+      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
+    return false;
+
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+    return false;
+
+  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
+  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
+
+  if (!GEPA || !GEPB)
+    return false;
+
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Check we are loading i8 values from two loop invariant pointers
+  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
+      !GEPA->getResultElementType()->isIntegerTy(8) ||
+      !GEPB->getResultElementType()->isIntegerTy(8) ||
+      !cast<LoadInst>(LoadA)->getType()->isIntegerTy(8) ||
+      !cast<LoadInst>(LoadB)->getType()->isIntegerTy(8) || PtrA == PtrB)
+    return false;
+
+  // Check that the index to the GEPs is the index we found earlier
+  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
+    return false;
+
+  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
+  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
+  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
+    return false;
+
+  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
+                    << *(EndBB->getParent()) << "\n\n");
+
+  // The index is incremented before the GEP/Load pair so we need to
+  // add 1 to the start value.
+  transformByteCompare(GEPA, GEPB, MaxLen, Index, StartIdx, /*IncIdx=*/true,
+                       FoundBB, EndBB);
+  return true;
+}
+
+Value *AArch64LoopIdiomTransform::expandFindMismatch(IRBuilder<> &Builder,
+                                                     GetElementPtrInst *GEPA,
+                                                     GetElementPtrInst *GEPB,
+                                                     Value *Start,
+                                                     Value *MaxLen) {
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Get the arguments and types for the intrinsic.
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
+  LLVMContext &Ctx = PHBranch->getContext();
+  Type *LoadType = Type::getInt8Ty(Ctx);
+  Type *ResType = Builder.getInt32Ty();
+
+  // Split block in the original loop preheader.
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  BasicBlock *EndBlock =
+      SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
+
+  // Create the blocks that we're going to need:
+  //  1. A block for checking the zero-extended length exceeds 0
+  //  2. A block to check that the start and end addresses of a given array
+  //     lie on the same page.
+  //  3. The SVE loop preheader.
+  //  4. The first SVE loop block.
+  //  5. The SVE loop increment block.
+  //  6. A block we can jump to from the SVE loop when a mismatch is found.
+  //  7. The first block of the scalar loop itself, containing PHIs, loads
+  //  and cmp.
+  //  8. A scalar loop increment block to increment the PHIs and go back
+  //  around the loop.
+
+  BasicBlock *MinItCheckBlock = BasicBlock::Create(
+      Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);
+
+  // Update the terminator added by SplitBlock to branch to the first block
+  Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);
+
+  BasicBlock *MemCheckBlock = BasicBlock::Create(
+      Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *SVELoopPreheaderBlock = BasicBlock::Create(
+      Ctx, "mismatch_sve_loop_preheader", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *SVELoopStartBlock = BasicBlock::Create(
+      Ctx, "mismatch_sve_loop", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *SVELoopIncBlock = BasicBlock::Create(
+      Ctx, "mismatch_sve_loop_inc", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *SVELoopMismatchBlock = BasicBlock::Create(
+      Ctx, "mismatch_sve_loop_found", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
+      Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *LoopStartBlock =
+      BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *LoopIncBlock = BasicBlock::Create(
+      Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);
+
+  DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
+                    {DominatorTree::Delete, Preheader, EndBlock}});
+
+  // Update LoopInfo with the new SVE & scalar loops.
+  auto SVELoop = LI->AllocateLoop();
+  auto ScalarLoop = LI->AllocateLoop();
+  if (CurLoop->getParentLoop()) {
+    CurLoop->getParentLoop()->addChildLoop(SVELoop);
+    CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
+  } else {
+    LI->addTopLevelLoop(SVELoop);
+    LI->addTopLevelLoop(ScalarLoop);
+  }
+
+  // Add the new basic blocks to their associated loops.
+  SVELoop->addBasicBlockToLoop(MinItCheckBlock, *LI);
+  SVELoop->addBasicBlockToLoop(MemCheckBlock, *LI);
+  SVELoop->addBasicBlockToLoop(SVELoopPreheaderBlock, *LI);
+  SVELoop->addBasicBlockToLoop(SVELoopStartBlock, *LI);
+  SVELoop->addBasicBlockToLoop(SVELoopIncBlock, *LI);
+  SVELoop->addBasicBlockToLoop(SVELoopMismatchBlock, *LI);
+
+  ScalarLoop->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
+  ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
+  ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);
+
+  // Set up some types and constants that we intend to reuse.
+  Type *I64Type = Builder.getInt64Ty();
+
+  // Check the zero-extended iteration count > 0
+  Builder.SetInsertPoint(MinItCheckBlock);
+  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
+  Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
+  // This check doesn't really cost us very much.
+
+  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
+  BranchInst *MinItCheckBr =
+      BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
+  MinItCheckBr->setMetadata(
+      LLVMContext::MD_prof,
+      MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
+  Builder.Insert(MinItCheckBr);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
+       {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});
+
+  // For each of the arrays, check the start/end addresses are...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/72273

