[clang] [llvm] [clang-tools-extra] [AArch64] Add an AArch64 pass for loop idiom transformations (PR #72273)

David Sherwood via cfe-commits cfe-commits at lists.llvm.org
Fri Dec 22 09:07:12 PST 2023


================
@@ -0,0 +1,816 @@
+//===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that recognizes certain loop idioms and
+// transforms them into more optimized versions of the same loop. In cases
+// where this happens, it can be a significant performance win.
+//
+// We currently only recognize one loop that finds the first mismatched byte
+// in an array and returns the index, i.e. something like:
+//
+//  while (++i != n) {
+//    if (a[i] != b[i])
+//      break;
+//  }
+//
+// In this example we can actually vectorize the loop despite the early exit,
+// although the loop vectorizer does not support it. It requires some extra
+// checks to deal with the possibility of faulting loads when crossing page
+// boundaries. However, even with these checks it is still profitable to do the
+// transformation.
+//
+//===----------------------------------------------------------------------===//
+//
+// TODO List:
+//
+// * When optimizing for code size we may want to avoid some transformations.
+// * We can also support the inverse case where we scan for a matching element.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AArch64LoopIdiomTransform.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-loop-idiom-transform"
+
+// Master switch: when set, the pass returns without touching any loop.
+static cl::opt<bool>
+    DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false),
+               cl::desc("Disable AArch64 Loop Idiom Transform Pass."));
+
+// Leaves the pass enabled but switches off the byte-compare loop conversion.
+static cl::opt<bool> DisableByteCmp(
+    "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false),
+    cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do "
+             "not convert byte-compare loop(s)."));
+
+// Enables extra verification of the loops created by the transformation.
+static cl::opt<bool> VerifyLoops(
+    "aarch64-lit-verify", cl::Hidden, cl::init(false),
+    cl::desc("Verify loops generated by AArch64 Loop Idiom Transform Pass."));
+
+namespace llvm {
+
+// Forward declarations of the legacy pass manager hooks used in this file.
+// createAArch64LoopIdiomTransformPass is defined further down; the initialize
+// function is presumably generated by the INITIALIZE_PASS_* macros below.
+void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
+Pass *createAArch64LoopIdiomTransformPass();
+
+} // end namespace llvm
+
+namespace {
+
+/// Implementation of the loop idiom transformation. It is constructed with
+/// the analyses it needs and processes one loop per call to run().
+class AArch64LoopIdiomTransform {
+  // The loop currently being processed; assigned at the start of run().
+  Loop *CurLoop = nullptr;
+  DominatorTree *DT;
+  LoopInfo *LI;
+  const TargetTransformInfo *TTI;
+  const DataLayout *DL;
+
+public:
+  /// Stores the given analysis pointers; nothing else happens until run().
+  explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
+                                     const TargetTransformInfo *TTI,
+                                     const DataLayout *DL)
+      : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
+
+  /// Attempts to transform \p L and returns true if the loop was changed.
+  bool run(Loop *L);
+
+private:
+  /// \name Countable Loop Idiom Handling
+  /// @{
+
+  // NOTE(review): no definition of runOnCountableLoop or runOnLoopBlock is
+  // visible in this part of the file - confirm they are still needed.
+  bool runOnCountableLoop();
+  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
+                      SmallVectorImpl<BasicBlock *> &ExitBlocks);
+
+  /// Matches the byte-compare loop idiom on CurLoop and transforms it when
+  /// found. Returns true if the loop was transformed.
+  bool recognizeByteCompare();
+  /// Emits the mismatch-finding code in front of the loop and returns the
+  /// value holding the resulting index.
+  Value *expandFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
+                            GetElementPtrInst *GEPB, Instruction *Index,
+                            Value *Start, Value *MaxLen);
+  /// Replaces the matched loop's index computation with the result of
+  /// expandFindMismatch.
+  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
+                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
+                            BasicBlock *EndBB);
+  /// @}
+};
+
+/// Legacy pass manager wrapper that runs AArch64LoopIdiomTransform on a loop.
+class AArch64LoopIdiomTransformLegacyPass : public LoopPass {
+public:
+  // Pass identifier; defined after the anonymous namespace.
+  static char ID;
+
+  explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) {
+    initializeAArch64LoopIdiomTransformLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  StringRef getPassName() const override {
+    return "Transform AArch64-specific loop idioms";
+  }
+
+  // These are the analyses that runOnLoop fetches through getAnalysis<>.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+  }
+
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+};
+
+/// Legacy pass manager entry point: gathers the required analyses and hands
+/// the loop to AArch64LoopIdiomTransform. Returns whatever run() reports.
+bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L,
+                                                    LPPassManager &LPM) {
+  if (skipLoop(L))
+    return false;
+
+  BasicBlock *LoopHeader = L->getHeader();
+  DominatorTree &DomTree =
+      getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  LoopInfo &Loops = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  const TargetTransformInfo &TTI =
+      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+          *LoopHeader->getParent());
+  const DataLayout &Layout = LoopHeader->getModule()->getDataLayout();
+
+  AArch64LoopIdiomTransform Transform(&DomTree, &Loops, &TTI, &Layout);
+  return Transform.run(L);
+}
+
+} // end anonymous namespace
+
+// Definition of the static pass ID declared in the legacy pass class.
+char AArch64LoopIdiomTransformLegacyPass::ID = 0;
+
+// Legacy pass initialization for "aarch64-lit", together with the passes it
+// declares as dependencies.
+INITIALIZE_PASS_BEGIN(
+    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
+    "Transform specific loop idioms into optimized vector forms", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(
+    AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
+    "Transform specific loop idioms into optimized vector forms", false, false)
+
+/// Returns a newly allocated AArch64LoopIdiomTransformLegacyPass.
+Pass *llvm::createAArch64LoopIdiomTransformPass() {
+  return new AArch64LoopIdiomTransformLegacyPass();
+}
+
+/// New pass manager entry point. Reports every analysis as preserved when the
+/// pass is disabled or the loop was left untouched, and none otherwise.
+PreservedAnalyses
+AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
+                                   LoopStandardAnalysisResults &AR,
+                                   LPMUpdater &) {
+  if (DisableAll)
+    return PreservedAnalyses::all();
+
+  const DataLayout &Layout = L.getHeader()->getModule()->getDataLayout();
+  AArch64LoopIdiomTransform Transform(&AR.DT, &AR.LI, &AR.TTI, &Layout);
+
+  const bool Changed = Transform.run(&L);
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
+
+//===----------------------------------------------------------------------===//
+//
+//          Implementation of AArch64LoopIdiomTransform
+//
+//===----------------------------------------------------------------------===//
+
+/// Records \p L as the current loop and tries the byte-compare idiom on it.
+/// Returns true if the loop was transformed.
+bool AArch64LoopIdiomTransform::run(Loop *L) {
+  CurLoop = L;
+
+  // Nothing to do when the pass is switched off. A loop without a preheader
+  // could not be converted to canonical form (it must have an indirectbr in
+  // it), so give up on that as well.
+  if (DisableAll || !L->getLoopPreheader())
+    return false;
+
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
+                    << CurLoop->getHeader()->getParent()->getName()
+                    << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
+
+  const bool Transformed = recognizeByteCompare();
+  return Transformed;
+}
+
+
+/// Tries to match the "find the first mismatching byte" loop on CurLoop and,
+/// on success, hands it to transformByteCompare.
+///
+/// The loop must consist of exactly two blocks: a header that increments a
+/// 32-bit index and compares it against a limit, and a body that loads one i8
+/// from each of two distinct loop-invariant pointers at that index and leaves
+/// the loop when the bytes differ.
+///
+/// \returns true if the idiom was recognized and the loop was transformed.
+bool AArch64LoopIdiomTransform::recognizeByteCompare() {
+  // Currently the transformation only works on scalable vector types, although
+  // there is no fundamental reason why it cannot be made to work for fixed
+  // width too.
+
+  // We also need to know the minimum page size for the target in order to
+  // generate runtime memory checks to ensure the vector version won't fault.
+  if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+      DisableByteCmp)
+    return false;
+
+  BasicBlock *Header = CurLoop->getHeader();
+
+  // In AArch64LoopIdiomTransform::run we have already checked that the loop
+  // has a preheader so we can assume it's in a canonical form.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+    return false;
+
+  PHINode *PN = dyn_cast<PHINode>(&Header->front());
+  if (!PN || PN->getNumIncomingValues() != 2)
+    return false;
+
+  auto LoopBlocks = CurLoop->getBlocks();
+  // The first block in the loop should contain only 4 instructions, e.g.
+  //
+  //  while.cond:
+  //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
+  //   %inc = add i32 %res.phi, 1
+  //   %cmp.not = icmp eq i32 %inc, %n
+  //   br i1 %cmp.not, label %while.end, label %while.body
+  //
+  auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
+  if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
+    return false;
+
+  // The second block should contain 7 instructions, e.g.
+  //
+  // while.body:
+  //   %idx = zext i32 %inc to i64
+  //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
+  //   %load.a = load i8, ptr %idx.a
+  //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
+  //   %load.b = load i8, ptr %idx.b
+  //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
+  //   br i1 %cmp.not.ld, label %while.cond, label %while.end
+  //
+  auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
+  if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+    return false;
+
+  using namespace PatternMatch;
+
+  // The incoming value to the PHI node from the loop should be an add of 1.
+  Value *StartIdx = nullptr;
+  Instruction *Index = nullptr;
+  if (!CurLoop->contains(PN->getIncomingBlock(0))) {
+    StartIdx = PN->getIncomingValue(0);
+    Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
+  } else {
+    StartIdx = PN->getIncomingValue(1);
+    Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
+  }
+
+  // Limit to 32-bit types for now
+  if (!Index || !Index->getType()->isIntegerTy(32) ||
+      !match(Index, m_c_Add(m_Specific(PN), m_One())))
+    return false;
+
+  // The increment matched above must be the only user of the PHI.
+  // transformByteCompare replaces every use of PN with the final index, which
+  // would be off by one for any other use of the pre-increment value, e.g. a
+  // use outside of the loop.
+  if (!PN->hasOneUse())
+    return false;
+
+  // If we match the pattern, PN and Index will be replaced with the result of
+  // the cttz.elts intrinsic. If any other instructions are used outside of
+  // the loop, we cannot replace it.
+  for (BasicBlock *BB : LoopBlocks)
+    for (Instruction &I : *BB)
+      if (&I != PN && &I != Index)
+        for (User *U : I.users())
+          if (!CurLoop->contains(cast<Instruction>(U)))
+            return false;
+
+  // Match the branch instruction for the header
+  ICmpInst::Predicate Pred;
+  Value *MaxLen;
+  BasicBlock *EndBB, *WhileBB;
+  if (!match(Header->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Specific(Index), m_Value(MaxLen)),
+                  m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
+      Pred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(WhileBB))
+    return false;
+
+  // WhileBB should contain the pattern of load & compare instructions. Match
+  // the pattern and find the GEP instructions used by the loads.
+  ICmpInst::Predicate WhilePred;
+  BasicBlock *FoundBB;
+  BasicBlock *TrueBB;
+  Value *LoadA, *LoadB;
+  if (!match(WhileBB->getTerminator(),
+             m_Br(m_ICmp(WhilePred, m_Value(LoadA), m_Value(LoadB)),
+                  m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
+      WhilePred != ICmpInst::Predicate::ICMP_EQ || !CurLoop->contains(TrueBB))
+    return false;
+
+  Value *A, *B;
+  if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
+    return false;
+
+  // The loads are going to be replaced by masked vector loads that may read
+  // beyond the point where the scalar loop would have stopped. That is only
+  // acceptable for plain loads, so reject volatile and atomic accesses.
+  LoadInst *LoadAI = cast<LoadInst>(LoadA);
+  LoadInst *LoadBI = cast<LoadInst>(LoadB);
+  if (!LoadAI->isSimple() || !LoadBI->isSimple())
+    return false;
+
+  GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
+  GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
+
+  if (!GEPA || !GEPB)
+    return false;
+
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Check we are loading i8 values from two loop invariant pointers
+  if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
+      !GEPA->getResultElementType()->isIntegerTy(8) ||
+      !GEPB->getResultElementType()->isIntegerTy(8) ||
+      !LoadAI->getType()->isIntegerTy(8) ||
+      !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
+    return false;
+
+  // Check that the index to the GEPs is the index we found earlier
+  if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
+    return false;
+
+  Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
+  Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
+  if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
+    return false;
+
+  // Ensure that when the Found and End blocks are identical the PHIs have the
+  // supported format. We don't currently allow cases like this:
+  // while.cond:
+  //   ...
+  //   br i1 %cmp.not, label %while.end, label %while.body
+  //
+  // while.body:
+  //   ...
+  //   br i1 %cmp.not2, label %while.cond, label %while.end
+  //
+  // while.end:
+  //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
+  //
+  // Where the incoming values for %final_ptr are unique and from each of the
+  // loop blocks, but not actually defined in the loop. This requires extra
+  // work setting up the byte.compare block, i.e. by introducing a select to
+  // choose the correct value.
+  // TODO: We could add support for this in future.
+  if (FoundBB == EndBB) {
+    for (PHINode &PN : EndBB->phis()) {
+      Value *LastValue = nullptr;
+      for (unsigned I = 0; I < PN.getNumIncomingValues(); I++) {
+        BasicBlock *BB = PN.getIncomingBlock(I);
+        if (CurLoop->contains(BB)) {
+          Value *V = PN.getIncomingValue(I);
+          if (!LastValue)
+            LastValue = V;
+          else if (LastValue != V)
+            return false;
+        }
+      }
+    }
+  }
+
+  LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
+                    << *(EndBB->getParent()) << "\n\n");
+
+  // The index is incremented before the GEP/Load pair so we need to
+  // add 1 to the start value.
+  transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
+                       FoundBB, EndBB);
+  return true;
+}
+
+/// Emits the code that searches for the first index at which the bytes read
+/// through the base pointers of \p GEPA and \p GEPB differ.
+///
+/// The loop preheader is split at its terminator and a chain of new blocks is
+/// inserted in between: an iteration-count check, a page-crossing check, an
+/// SVE loop and a scalar fallback loop. All of them end up in the new
+/// "mismatch_end" block. The dominator tree and LoopInfo are updated for the
+/// new blocks and loops.
+///
+/// \param Builder builder used to create all the new instructions.
+/// \param GEPA    GEP feeding the first load; provides the base pointer and
+///                the inbounds flag.
+/// \param GEPB    GEP feeding the second load; used like \p GEPA.
+/// \param Index   the original index increment; only its nuw/nsw flags are
+///                copied onto the scalar loop increment.
+/// \param Start   32-bit index of the first byte to compare.
+/// \param MaxLen  32-bit index at which the search stops.
+/// \returns the 32-bit index of the first mismatch, or \p MaxLen if none was
+///          found.
+Value *AArch64LoopIdiomTransform::expandFindMismatch(
+    IRBuilder<> &Builder, GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+    Instruction *Index, Value *Start, Value *MaxLen) {
+  Value *PtrA = GEPA->getPointerOperand();
+  Value *PtrB = GEPB->getPointerOperand();
+
+  // Get the arguments and types for the intrinsic.
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
+  LLVMContext &Ctx = PHBranch->getContext();
+  Type *LoadType = Type::getInt8Ty(Ctx);
+  Type *ResType = Builder.getInt32Ty();
+
+  // Split block in the original loop preheader.
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  BasicBlock *EndBlock =
+      SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
+
+  // Create the blocks that we're going to need:
+  //  1. A block for checking the zero-extended length exceeds 0
+  //  2. A block to check that the start and end addresses of a given array
+  //     lie on the same page.
+  //  3. The SVE loop preheader.
+  //  4. The first SVE loop block.
+  //  5. The SVE loop increment block.
+  //  6. A block we can jump to from the SVE loop when a mismatch is found.
+  //  7. The scalar loop preheader.
+  //  8. The first block of the scalar loop itself, containing PHIs, loads
+  //     and cmp.
+  //  9. A scalar loop increment block to increment the PHIs and go back
+  //     around the loop.
+
+  BasicBlock *MinItCheckBlock = BasicBlock::Create(
+      Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);
+
+  // Update the terminator added by SplitBlock to branch to the first block
+  Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);
+
+  BasicBlock *MemCheckBlock = BasicBlock::Create(
+      Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *SVELoopPreheaderBlock = BasicBlock::Create(
+      Ctx, "mismatch_sve_loop_preheader", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *SVELoopStartBlock = BasicBlock::Create(
+      Ctx, "mismatch_sve_loop", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *SVELoopIncBlock = BasicBlock::Create(
+      Ctx, "mismatch_sve_loop_inc", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *SVELoopMismatchBlock = BasicBlock::Create(
+      Ctx, "mismatch_sve_loop_found", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
+      Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *LoopStartBlock =
+      BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);
+
+  BasicBlock *LoopIncBlock = BasicBlock::Create(
+      Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);
+
+  DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
+                    {DominatorTree::Delete, Preheader, EndBlock}});
+
+  // Update LoopInfo with the new SVE & scalar loops.
+  auto SVELoop = LI->AllocateLoop();
+  auto ScalarLoop = LI->AllocateLoop();
+
+  // When the original loop is nested, the new straight-line blocks and the
+  // two new loops become members of its parent loop.
+  if (CurLoop->getParentLoop()) {
+    CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
+    CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
+    CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopPreheaderBlock, *LI);
+    CurLoop->getParentLoop()->addChildLoop(SVELoop);
+    CurLoop->getParentLoop()->addBasicBlockToLoop(SVELoopMismatchBlock, *LI);
+    CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
+    CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
+  } else {
+    LI->addTopLevelLoop(SVELoop);
+    LI->addTopLevelLoop(ScalarLoop);
+  }
+
+  // Add the new basic blocks to their associated loops.
+  SVELoop->addBasicBlockToLoop(SVELoopStartBlock, *LI);
+  SVELoop->addBasicBlockToLoop(SVELoopIncBlock, *LI);
+
+  ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
+  ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);
+
+  // Set up some types and constants that we intend to reuse.
+  Type *I64Type = Builder.getInt64Ty();
+
+  // Check the zero-extended iteration count > 0
+  Builder.SetInsertPoint(MinItCheckBlock);
+  Value *ExtStart = Builder.CreateZExt(Start, I64Type);
+  Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
+  // This check doesn't really cost us very much.
+
+  // If Start is above MaxLen we skip the vector code and go straight to the
+  // scalar loop.
+  Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
+  BranchInst *MinItCheckBr =
+      BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
+  MinItCheckBr->setMetadata(
+      LLVMContext::MD_prof,
+      MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
+  Builder.Insert(MinItCheckBr);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
+       {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});
+
+  // For each of the arrays, check the start/end addresses are on the same
+  // page.
+  Builder.SetInsertPoint(MemCheckBlock);
+
+  // The early exit in the original loop means that when performing vector
+  // loads we are potentially reading ahead of the early exit. So we could
+  // fault if crossing a page boundary. Therefore, we create runtime memory
+  // checks based on the minimum page size as follows:
+  //   1. Calculate the addresses of the first memory accesses in the loop,
+  //      i.e. LhsStart and RhsStart.
+  //   2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
+  //   3. Determine which pages correspond to all the memory accesses, i.e
+  //      LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
+  //   4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
+  //      we know we won't cross any page boundaries in the loop so we can
+  //      enter the vector loop! Otherwise we fall back on the scalar loop.
+  Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
+  Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
+  Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
+  Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
+  Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
+  Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
+  Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
+  Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);
+
+  // Shifting an address right by log2(page size) yields its page number.
+  const uint64_t MinPageSize = TTI->getMinPageSize().value();
+  const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
+  Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
+  Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
+  Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
+  Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
+  Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
+  Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);
+
+  Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
+  BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
+      LoopPreHeaderBlock, SVELoopPreheaderBlock, CombinedPageCmp);
+  CombinedPageCmpCmpBr->setMetadata(
+      LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
+                                .createBranchWeights(10, 90));
+  Builder.Insert(CombinedPageCmpCmpBr);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
+       {DominatorTree::Insert, MemCheckBlock, SVELoopPreheaderBlock}});
+
+  // Set up the SVE loop preheader, i.e. calculate initial loop predicate,
+  // zero-extend MaxLen to 64-bits, determine the number of vector elements
+  // processed in each iteration, etc.
+  Builder.SetInsertPoint(SVELoopPreheaderBlock);
+
+  // At this point we know two things must be true:
+  //  1. Start <= End
+  //  2. ExtMaxLen <= MinPageSize due to the page checks.
+  // Therefore, we know that we can use a 64-bit induction variable that
+  // starts from 0 -> ExtMaxLen and it will not overflow.
+  ScalableVectorType *PredVTy =
+      ScalableVectorType::get(Builder.getInt1Ty(), 16);
+
+  Value *InitialPred = Builder.CreateIntrinsic(
+      Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
+
+  // Number of bytes handled per vector iteration: vscale * 16.
+  Value *VecLen = Builder.CreateIntrinsic(Intrinsic::vscale, {I64Type}, {});
+  VecLen = Builder.CreateMul(VecLen, ConstantInt::get(I64Type, 16), "",
+                             /*HasNUW=*/true, /*HasNSW=*/true);
+
+  Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
+                                            Builder.getInt1(false));
+
+  BranchInst *JumpToSVELoop = BranchInst::Create(SVELoopStartBlock);
+  Builder.Insert(JumpToSVELoop);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, SVELoopPreheaderBlock, SVELoopStartBlock}});
+
+  // Set up the first SVE loop block by creating the PHIs, doing the vector
+  // loads and comparing the vectors.
+  Builder.SetInsertPoint(SVELoopStartBlock);
+  PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_sve_loop_pred");
+  LoopPred->addIncoming(InitialPred, SVELoopPreheaderBlock);
+  PHINode *SVEIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_sve_index");
+  SVEIndexPhi->addIncoming(ExtStart, SVELoopPreheaderBlock);
+  Type *SVELoadType = ScalableVectorType::get(Builder.getInt8Ty(), 16);
+  Value *GepOffset = SVEIndexPhi;
+  Value *Passthru = ConstantInt::getNullValue(SVELoadType);
+
+  Value *SVELhsGep = Builder.CreateGEP(LoadType, PtrA, GepOffset);
+  if (GEPA->isInBounds())
+    cast<GetElementPtrInst>(SVELhsGep)->setIsInBounds(true);
+  Value *SVELhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVELhsGep, Align(1),
+                                               LoopPred, Passthru);
+
+  Value *SVERhsGep = Builder.CreateGEP(LoadType, PtrB, GepOffset);
+  if (GEPB->isInBounds())
+    cast<GetElementPtrInst>(SVERhsGep)->setIsInBounds(true);
+  Value *SVERhsLoad = Builder.CreateMaskedLoad(SVELoadType, SVERhsGep, Align(1),
+                                               LoopPred, Passthru);
+
+  // Only mismatches in lanes enabled by the loop predicate count.
+  Value *SVEMatchCmp = Builder.CreateICmpNE(SVELhsLoad, SVERhsLoad);
+  SVEMatchCmp = Builder.CreateSelect(LoopPred, SVEMatchCmp, PFalse);
+  Value *SVEMatchHasActiveLanes = Builder.CreateOrReduce(SVEMatchCmp);
+  BranchInst *SVEEarlyExit = BranchInst::Create(
+      SVELoopMismatchBlock, SVELoopIncBlock, SVEMatchHasActiveLanes);
+  Builder.Insert(SVEEarlyExit);
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, SVELoopStartBlock, SVELoopMismatchBlock},
+       {DominatorTree::Insert, SVELoopStartBlock, SVELoopIncBlock}});
+
+  // Increment the index counter and calculate the predicate for the next
+  // iteration of the loop. We branch back to the start of the loop if there
+  // is at least one active lane.
+  Builder.SetInsertPoint(SVELoopIncBlock);
+  Value *NewSVEIndexPhi = Builder.CreateAdd(SVEIndexPhi, VecLen, "",
+                                            /*HasNUW=*/true, /*HasNSW=*/true);
+  SVEIndexPhi->addIncoming(NewSVEIndexPhi, SVELoopIncBlock);
+  Value *NewPred =
+      Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
+                              {PredVTy, I64Type}, {NewSVEIndexPhi, ExtEnd});
+  LoopPred->addIncoming(NewPred, SVELoopIncBlock);
+
+  Value *PredHasActiveLanes =
+      Builder.CreateExtractElement(NewPred, uint64_t(0));
+  BranchInst *SVELoopBranchBack =
+      BranchInst::Create(SVELoopStartBlock, EndBlock, PredHasActiveLanes);
+  Builder.Insert(SVELoopBranchBack);
+
+  DTU.applyUpdates({{DominatorTree::Insert, SVELoopIncBlock, SVELoopStartBlock},
+                    {DominatorTree::Insert, SVELoopIncBlock, EndBlock}});
+
+  // If we found a mismatch then we need to calculate which lane in the vector
+  // had a mismatch and add that on to the current loop index.
+  Builder.SetInsertPoint(SVELoopMismatchBlock);
+  PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_sve_found_pred");
+  FoundPred->addIncoming(SVEMatchCmp, SVELoopStartBlock);
+  PHINode *LastLoopPred =
+      Builder.CreatePHI(PredVTy, 1, "mismatch_sve_last_loop_pred");
+  LastLoopPred->addIncoming(LoopPred, SVELoopStartBlock);
+  PHINode *SVEFoundIndex =
+      Builder.CreatePHI(I64Type, 1, "mismatch_sve_found_index");
+  SVEFoundIndex->addIncoming(SVEIndexPhi, SVELoopStartBlock);
+
+  Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
+  Value *Ctz = Builder.CreateIntrinsic(
+      Intrinsic::experimental_cttz_elts, {ResType, PredMatchCmp->getType()},
+      {PredMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(true)});
+  Ctz = Builder.CreateZExt(Ctz, I64Type);
+  Value *SVELoopRes64 = Builder.CreateAdd(SVEFoundIndex, Ctz, "",
+                                          /*HasNUW=*/true, /*HasNSW=*/true);
+  Value *SVELoopRes = Builder.CreateTrunc(SVELoopRes64, ResType);
+
+  Builder.Insert(BranchInst::Create(EndBlock));
+
+  DTU.applyUpdates({{DominatorTree::Insert, SVELoopMismatchBlock, EndBlock}});
+
+  // Generate code for scalar loop.
+  Builder.SetInsertPoint(LoopPreHeaderBlock);
+  Builder.Insert(BranchInst::Create(LoopStartBlock));
+
+  DTU.applyUpdates(
+      {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});
+
+  Builder.SetInsertPoint(LoopStartBlock);
+  PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
+  IndexPhi->addIncoming(Start, LoopPreHeaderBlock);
+
+  // Load bytes from each array and compare them.
+  GepOffset = Builder.CreateZExt(IndexPhi, I64Type);
+
+  Value *LhsGep = Builder.CreateGEP(LoadType, PtrA, GepOffset);
+  if (GEPA->isInBounds())
+    cast<GetElementPtrInst>(LhsGep)->setIsInBounds(true);
+  Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);
+
+  Value *RhsGep = Builder.CreateGEP(LoadType, PtrB, GepOffset);
+  if (GEPB->isInBounds())
+    cast<GetElementPtrInst>(RhsGep)->setIsInBounds(true);
+  Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);
+
+  Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
+  // If we have a mismatch then exit the loop ...
+  BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
+  Builder.Insert(MatchCmpBr);
+
+  DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
+                    {DominatorTree::Insert, LoopStartBlock, EndBlock}});
+
+  // Have we reached the maximum permitted length for the loop?
+  Builder.SetInsertPoint(LoopIncBlock);
+  Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
+                                    /*HasNUW=*/Index->hasNoUnsignedWrap(),
+                                    /*HasNSW=*/Index->hasNoSignedWrap());
+  IndexPhi->addIncoming(PhiInc, LoopIncBlock);
+  Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
+  BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
+  Builder.Insert(IVCmpBr);
+
+  DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
+                    {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});
+
+  // In the end block we need to insert a PHI node to deal with four cases:
+  //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
+  //  2. We exited the scalar loop early due to a mismatch and need to return
+  //  the index that we found.
+  //  3. We didn't find a mismatch in the SVE loop, so we return MaxLen.
+  //  4. We exited the SVE loop early due to a mismatch and need to return
+  //  the index that we found.
+  Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
+  PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
+  ResPhi->addIncoming(MaxLen, LoopIncBlock);
+  ResPhi->addIncoming(IndexPhi, LoopStartBlock);
+  ResPhi->addIncoming(MaxLen, SVELoopIncBlock);
+  ResPhi->addIncoming(SVELoopRes, SVELoopMismatchBlock);
+
+  // NOTE(review): ResPhi is created with ResType above, so this truncation
+  // looks like a no-op - confirm whether it can be dropped.
+  Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);
+
+  if (VerifyLoops) {
+    ScalarLoop->verifyLoop();
+    SVELoop->verifyLoop();
+    if (!SVELoop->isRecursivelyLCSSAForm(*DT, *LI))
+      report_fatal_error("Loops must remain in LCSSA form!");
+    if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
+      report_fatal_error("Loops must remain in LCSSA form!");
+  }
+
+  return FinalRes;
+}
+
+void AArch64LoopIdiomTransform::transformByteCompare(
+    GetElementPtrInst *GEPA, GetElementPtrInst *GEPB, PHINode *IndPhi,
+    Value *MaxLen, Instruction *Index, Value *Start, bool IncIdx,
+    BasicBlock *FoundBB, BasicBlock *EndBB) {
+
+  // Insert the byte compare code at the end of the preheader block
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+  BasicBlock *Header = CurLoop->getHeader();
+  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
+  IRBuilder<> Builder(PHBranch);
+  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
+
+  // Increment the pointer if this was done before the loads in the loop.
+  if (IncIdx)
+    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));
+
+  Value *ByteCmpRes =
+      expandFindMismatch(Builder, GEPA, GEPB, Index, Start, MaxLen);
+
+  // Replaces uses of index & induction Phi with intrinsic (we already
+  // checked that the the first instruction of Header is the Phi above).
+  IndPhi->replaceAllUsesWith(ByteCmpRes);
+  Index->replaceAllUsesWith(ByteCmpRes);
----------------
david-arm wrote:

For this version of the patch there shouldn't be more than one use of IndPhi and it should be inside the loop only. We should only ever see outside uses of Index. For now, I've added more checks to recognizeByteCompare that IndPhi has one use, and added an assert here.

https://github.com/llvm/llvm-project/pull/72273


More information about the cfe-commits mailing list