[llvm] [AArch64] Add an AArch64 pass for loop idiom transformations (PR #72273)
Antonio Frighetto via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 30 07:58:17 PST 2023
================
@@ -0,0 +1,768 @@
+//===- AArch64LoopIdiomTransform.cpp - Loop idiom recognition -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that recognises certain loop idioms and
+// transforms them into more optimised versions of the same loop. In cases
+// where this happens, it can be a significant performance win.
+//
+// We currently only recognise one loop idiom: a loop that finds the first
+// mismatched byte between two arrays and returns its index, e.g.:
+//
+// while (++i != n) {
+// if (a[i] != b[i])
+// break;
+// }
+//
+// In this example we can actually vectorise the loop despite the early exit,
+// although the loop vectorizer does not support it. It requires some extra
+// checks to deal with the possibility of faulting loads when crossing page
+// boundaries. However, even with these checks it is still profitable to do the
+// transformation.
+//
+//===----------------------------------------------------------------------===//
+//
+// TODO List:
+//
+// * When optimising for code size we may want to avoid some transformations.
+// * We can also support the inverse case where we scan for a matching element.
+//
+//===----------------------------------------------------------------------===//
+
+#include "AArch64LoopIdiomTransform.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "aarch64-loop-idiom-transform"
+
+// Hidden flag, off by default. When set, both run() entry points in this file
+// return early without changing the loop.
+static cl::opt<bool>
+ DisableAll("disable-aarch64-lit-all", cl::Hidden, cl::init(false),
+ cl::desc("Disable AArch64 Loop Idiom Transform Pass."));
+
+// Hidden flag, off by default. When set, recognizeByteCompare() returns false
+// before it attempts to match the byte-compare loop.
+static cl::opt<bool> DisableByteCmp(
+ "disable-aarch64-lit-bytecmp", cl::Hidden, cl::init(false),
+ cl::desc("Proceed with AArch64 Loop Idiom Transform Pass, but do "
+ "not convert byte-compare loop(s)."));
+
+namespace llvm {
+
+// Forward declarations for the legacy pass defined in this file.
+// initialize...PassPass is called from the legacy pass constructor;
+// presumably it is generated by the INITIALIZE_PASS_* macros further down --
+// confirm.
+void initializeAArch64LoopIdiomTransformLegacyPassPass(PassRegistry &);
+// Factory for the legacy pass; defined later in this file.
+Pass *createAArch64LoopIdiomTransformPass();
+
+} // end namespace llvm
+
+namespace {
+
+/// Implementation shared by the legacy and new pass-manager wrappers in this
+/// file: each wrapper constructs one of these and calls run() on a single
+/// loop.
+class AArch64LoopIdiomTransform {
+ // Loop currently being processed; assigned at the start of run().
+ Loop *CurLoop = nullptr;
+ // Analyses and target/data-layout information supplied through the
+ // constructor. They are kept as raw pointers; nothing in this class frees
+ // them.
+ DominatorTree *DT;
+ LoopInfo *LI;
+ const TargetTransformInfo *TTI;
+ const DataLayout *DL;
+
+public:
+ /// Stores the supplied pointers; performs no other work.
+ explicit AArch64LoopIdiomTransform(DominatorTree *DT, LoopInfo *LI,
+ const TargetTransformInfo *TTI,
+ const DataLayout *DL)
+ : DT(DT), LI(LI), TTI(TTI), DL(DL) {}
+
+ /// Entry point. Records \p L as the current loop and returns the result of
+ /// recognizeByteCompare(), or false when the pass is disabled or the loop
+ /// has no preheader.
+ bool run(Loop *L);
+
+private:
+ /// \name Countable Loop Idiom Handling
+ /// @{
+
+ // NOTE(review): no definition or call of the next two functions is visible
+ // in this part of the file -- confirm they are still needed.
+ bool runOnCountableLoop();
+ bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
+ SmallVectorImpl<BasicBlock *> &ExitBlocks);
+
+ /// Checks whether CurLoop has the shape of the byte-compare loop described
+ /// in the file header. Returns false when the target lacks scalable vectors
+ /// or a known minimum page size, or when the byte-compare transform is
+ /// disabled. NOTE(review): presumably performs the rewrite and returns true
+ /// on a successful match -- confirm against the full definition.
+ bool recognizeByteCompare();
+ // NOTE(review): the two helpers below are defined outside the visible part
+ // of the file. Presumably expandFindMismatch emits the mismatch search and
+ // transformByteCompare rewires the loop to use it -- confirm.
+ Value *expandFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
+ GetElementPtrInst *GEPB, Value *Start,
+ Value *MaxLen);
+ void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+ Value *MaxLen, Value *Index, Value *Start,
+ bool IncIdx, BasicBlock *FoundBB,
+ BasicBlock *EndBB);
+ /// @}
+};
+
+/// Legacy pass-manager wrapper that runs AArch64LoopIdiomTransform on each
+/// loop it is given.
+class AArch64LoopIdiomTransformLegacyPass : public LoopPass {
+public:
+ // Pass identifier handed to the LoopPass constructor; defined after the
+ // anonymous namespace.
+ static char ID;
+
+ /// Registers the pass with the global PassRegistry on construction.
+ explicit AArch64LoopIdiomTransformLegacyPass() : LoopPass(ID) {
+ initializeAArch64LoopIdiomTransformLegacyPassPass(
+ *PassRegistry::getPassRegistry());
+ }
+
+ /// Human-readable pass name.
+ StringRef getPassName() const override {
+ return "Recognize AArch64-specific loop idioms";
+ }
+
+ /// Declares the analyses that runOnLoop() fetches with getAnalysis<>: loop
+ /// info, the dominator tree and target transform info.
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<TargetTransformInfoWrapperPass>();
+ }
+
+ /// Runs the idiom transform on \p L and returns the result of
+ /// AArch64LoopIdiomTransform::run(), or false if the loop is skipped.
+ bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+};
+
+/// Legacy pass-manager hook: gathers the analyses for the function containing
+/// \p L and forwards to AArch64LoopIdiomTransform::run(). Returns false
+/// without touching the loop when skipLoop() asks for it to be skipped.
+bool AArch64LoopIdiomTransformLegacyPass::runOnLoop(Loop *L,
+                                                    LPPassManager &LPM) {
+  if (skipLoop(L))
+    return false;
+
+  BasicBlock *Header = L->getHeader();
+  DominatorTree &DomTree =
+      getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  LoopInfo &Loops = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  const TargetTransformInfo &TargetInfo =
+      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+          *Header->getParent());
+  const DataLayout &Layout = Header->getModule()->getDataLayout();
+
+  AArch64LoopIdiomTransform Transform(&DomTree, &Loops, &TargetInfo, &Layout);
+  return Transform.run(L);
+}
+
+} // end anonymous namespace
+
+// Storage for the pass identifier that the legacy pass constructor hands to
+// LoopPass.
+char AArch64LoopIdiomTransformLegacyPass::ID = 0;
+
+// Legacy pass registration under the command-line name "aarch64-lit".
+// NOTE(review): LoopSimplify and LCSSAWrapperPass are listed as dependencies
+// here but are not requested in getAnalysisUsage() -- confirm this is
+// intended.
+INITIALIZE_PASS_BEGIN(
+ AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
+ "Transform specific loop idioms into optimised vector forms", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(
+ AArch64LoopIdiomTransformLegacyPass, "aarch64-lit",
+ "Transform specific loop idioms into optimised vector forms", false, false)
+
+/// Factory for the legacy pass: returns a newly allocated
+/// AArch64LoopIdiomTransformLegacyPass. NOTE(review): presumably the caller
+/// (the pass manager) takes ownership of the returned pointer -- confirm.
+Pass *llvm::createAArch64LoopIdiomTransformPass() {
+ return new AArch64LoopIdiomTransformLegacyPass();
+}
+
+/// New pass-manager entry point. Runs the idiom transform on \p L using the
+/// analyses in \p AR. Reports all analyses as preserved when the pass is
+/// disabled or nothing was changed, and none as preserved otherwise.
+PreservedAnalyses
+AArch64LoopIdiomTransformPass::run(Loop &L, LoopAnalysisManager &AM,
+                                   LoopStandardAnalysisResults &AR,
+                                   LPMUpdater &) {
+  if (DisableAll)
+    return PreservedAnalyses::all();
+
+  const DataLayout &Layout = L.getHeader()->getModule()->getDataLayout();
+  AArch64LoopIdiomTransform Transform(&AR.DT, &AR.LI, &AR.TTI, &Layout);
+
+  const bool Changed = Transform.run(&L);
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
+
+//===----------------------------------------------------------------------===//
+//
+// Implementation of AArch64LoopIdiomTransform
+//
+//===----------------------------------------------------------------------===//
+
+/// Records \p L as the loop being processed and tries to recognise the
+/// byte-compare idiom in it. Returns false if the pass is disabled or the
+/// loop has no preheader; otherwise returns what recognizeByteCompare()
+/// returns.
+bool AArch64LoopIdiomTransform::run(Loop *L) {
+  CurLoop = L;
+
+  // Give up when the pass is switched off. Also give up when there is no
+  // preheader: the loop could not be put into canonical form, which means it
+  // must contain an indirectbr.
+  if (DisableAll || !L->getLoopPreheader())
+    return false;
+
+  LLVM_DEBUG({
+    dbgs() << DEBUG_TYPE " Scanning: F["
+           << L->getHeader()->getParent()->getName() << "] Loop %"
+           << L->getHeader()->getName() << "\n";
+  });
+
+  return recognizeByteCompare();
+}
+
+/// Match loop-invariant value.
+/// Pattern-match wrapper: matches a value only when it is invariant in the
+/// stored loop and the wrapped sub-pattern also matches it.
+template <typename SubPattern_t> struct match_LoopInvariant {
+  SubPattern_t SubPattern;
+  const Loop *L;
+
+  match_LoopInvariant(const SubPattern_t &Inner, const Loop *TheLoop)
+      : SubPattern(Inner), L(TheLoop) {}
+
+  template <typename ITy> bool match(ITy *V) {
+    // Test invariance first so the sub-pattern is never evaluated for a
+    // loop-variant value.
+    if (!L->isLoopInvariant(V))
+      return false;
+    return SubPattern.match(V);
+  }
+};
+
+/// Matches if the value is loop-invariant.
+/// Builds a matcher that accepts a value only if it is invariant in \p L and
+/// also matches the sub-pattern \p M.
+template <typename Ty>
+inline match_LoopInvariant<Ty> m_LoopInvariant(const Ty &M, const Loop *L) {
+  match_LoopInvariant<Ty> Matcher(M, L);
+  return Matcher;
+}
+
+bool AArch64LoopIdiomTransform::recognizeByteCompare() {
+ // Currently the transformation only works on scalable vector types, although
+ // there is no fundamental reason why it cannot be made to work for fixed
+ // width too.
+
+ // We also need to know the minimum page size for the target in order to
+ // generate runtime memory checks to ensure the vector version won't fault.
+ if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
+ DisableByteCmp)
+ return false;
+
+ BasicBlock *Header = CurLoop->getHeader();
+ BasicBlock *PH = CurLoop->getLoopPreheader();
+
+ // In AArch64LoopIdiomTransform::run we have already checked that the loop
+ // has a preheader so we can assume it's in a canonical form.
+ auto *EntryBI = cast<BranchInst>(PH->getTerminator());
+
+ if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+ return false;
+
+ PHINode *PN = dyn_cast<PHINode>(&Header->front());
+ if (!PN || PN->getNumIncomingValues() != 2)
+ return false;
+
+ auto LoopBlocks = CurLoop->getBlocks();
+ // The first block in the loop should contain only 4 instructions, e.g.
+ //
+ // while.cond:
+ // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
+ // %inc = add i32 %res.phi, 1
+ // %cmp.not = icmp eq i32 %inc, %n
+ // br i1 %cmp.not, label %while.end, label %while.body
+ //
+ auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
+ if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) > 4)
+ return false;
+
+ // The second block should contain 7 instructions, e.g.
+ //
+ // while.body:
+ // %idx = zext i32 %inc to i64
+ // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
+ // %load.a = load i8, ptr %idx.a
+ // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
+ // %load.b = load i8, ptr %idx.b
+ // %cmp.not.ld = icmp eq i8 %load.a, %load.b
+ // br i1 %cmp.not.ld, label %while.cond, label %while.end
+ //
+ auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
+ if (std::distance(LoopBBInsts.begin(), LoopBBInsts.end()) > 7)
+ return false;
+
+ using namespace PatternMatch;
+
+ // The incoming value to the PHI node from the loop should be an add of 1.
+ Value *StartIdx = nullptr;
+ Instruction *Index = nullptr;
+ if (!CurLoop->contains(PN->getIncomingBlock(0))) {
+ StartIdx = PN->getIncomingValue(0);
+ Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
+ } else {
+ StartIdx = PN->getIncomingValue(1);
+ Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
+ }
+
+ // Limit to 32-bit types for now
+ if (!Index || !Index->getType()->isIntegerTy(32) ||
+ !match(Index, m_c_Add(m_Specific(PN), m_One())))
+ return false;
+
+ // If we match the pattern, PN and Index will be replaced with the result of
+ // the cttz.elts intrinsic. If any other instructions are used outside of
+ // the loop, we cannot replace it.
+ for (BasicBlock *BB : LoopBlocks)
+ for (Instruction &I : *BB)
+ if (&I != PN && &I != Index)
+ for (User *U : I.users()) {
+ auto UI = cast<Instruction>(U);
----------------
antoniofrighetto wrote:
```suggestion
if (!CurLoop->contains(cast<Instruction>(U)))
return false;
```
https://github.com/llvm/llvm-project/pull/72273
More information about the llvm-commits
mailing list