[llvm] [RISCV] Introduce the RISCVLoopIdiomRecognizePass (PR #92441)
Min-Yih Hsu via llvm-commits
llvm-commits at lists.llvm.org
Fri May 31 16:56:59 PDT 2024
================
@@ -0,0 +1,732 @@
+//===-------- RISCVLoopIdiomRecognize.cpp - Loop idiom recognition --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "RISCVLoopIdiomRecognize.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsRISCV.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/TargetParser/RISCVTargetParser.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "riscv-loop-idiom"
+
+// Master switch for the whole pass. It defaults to true, so the pass stays
+// off unless it is explicitly enabled on the command line.
+static cl::opt<bool>
+    DisableAll("riscv-disable-all-loop-idiom",
+               cl::desc("Disable RISCV Loop Idiom Recognize Pass."),
+               cl::init(true), cl::Hidden);
+
+// Keeps the pass running but skips the byte-compare idiom specifically.
+static cl::opt<bool> DisableByteCmp(
+    "disable-riscv-loop-idiom-bytecmp",
+    cl::desc("Proceed with RISCV Loop Idiom Recognize Pass, but do "
+             "not convert byte-compare loop(s)."),
+    cl::init(false), cl::Hidden);
+
+// Selects the LMUL used when vectorizing a recognized loop. The value is a
+// log2 exponent: 0 -> LMUL 1, 1 -> LMUL 2, 2 -> LMUL 4, 3 -> LMUL 8. Values
+// above 3 are clamped to 3 (LMUL 8) by getBestVectorTypeForLoopIdiom.
+static cl::opt<unsigned>
+    CustomLoopIdiomLMUL("riscv-loop-idiom-lmul",
+                        cl::desc("Customize LMUL for vector loop."),
+                        cl::init(1), cl::Hidden);
+
+namespace {
+
+/// Per-loop driver for the RISC-V loop idiom transforms. It borrows the
+/// analyses it is given by reference, so an instance must not outlive them.
+class RISCVLoopIdiomRecognize {
+  Loop *CurLoop = nullptr; // Loop under inspection; assigned by run().
+  DominatorTree &DT;
+  LoopInfo &LI;
+  TargetLibraryInfo &TLI;
+  const TargetTransformInfo &TTI;
+  const DataLayout &DL;
+
+public:
+  explicit RISCVLoopIdiomRecognize(DominatorTree &DomTree, LoopInfo &Loops,
+                                   TargetLibraryInfo &LibInfo,
+                                   const TargetTransformInfo &TargetInfo,
+                                   const DataLayout &Layout)
+      : DT(DomTree), LI(Loops), TLI(LibInfo), TTI(TargetInfo), DL(Layout) {}
+
+  /// Examines \p L and returns true when the IR was changed.
+  bool run(Loop *L);
+
+private:
+  /// \name Countable Loop Idiom Handling
+  /// @{
+
+  // NOTE(review): these two are not defined in the portion of the patch
+  // visible here. Confirm they are implemented and used, or drop them.
+  bool runOnCountableLoop();
+  bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
+                      SmallVectorImpl<BasicBlock *> &ExitBlocks);
+
+  /// Entry point for the byte-compare idiom. It bails out early when
+  /// -disable-riscv-loop-idiom-bytecmp is set.
+  bool recognizeAndTransformByteCompare();
+
+  /// Emits the mismatch search over the byte arrays addressed by
+  /// \p GEPA / \p GEPB, starting at \p Start and bounded by \p MaxLen.
+  Value *expandFindMismatch(IRBuilder<> &Builder, GetElementPtrInst *GEPA,
+                            GetElementPtrInst *GEPB, Instruction *Index,
+                            Value *Start, Value *MaxLen);
+
+  /// Rewrites the matched byte-compare loop, branching to \p FoundBB or
+  /// \p EndBB.
+  void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
+                            PHINode *IndPhi, Value *MaxLen, Instruction *Index,
+                            Value *Start, bool IncIdx, BasicBlock *FoundBB,
+                            BasicBlock *EndBB);
+
+  /// @}
+};
+} // end anonymous namespace
+
+/// Returns the scalable <vscale x N x i8> type used by the loop idiom
+/// transforms. N is one RVV block's worth of bytes
+/// (RISCV::RVVBitsPerBlock / 8) scaled by the LMUL chosen through
+/// -riscv-loop-idiom-lmul. The LMUL exponent is capped at 3 (LMUL 8).
+static VectorType *getBestVectorTypeForLoopIdiom(LLVMContext &Ctx) {
+  const unsigned MaxLMULExp = 3;
+  const unsigned RequestedExp = CustomLoopIdiomLMUL.getValue();
+  const unsigned LMULExp = RequestedExp < MaxLMULExp ? RequestedExp : MaxLMULExp;
+  const unsigned BytesPerBlock = RISCV::RVVBitsPerBlock / 8;
+  Type *ByteTy = Type::getInt8Ty(Ctx);
+  return VectorType::get(ByteTy,
+                         ElementCount::getScalable(BytesPerBlock << LMULExp));
+}
+
+/// New-PM entry point. Runs the RISC-V loop idiom transforms on \p L when the
+/// pass is enabled, the function permits vector code (no NoImplicitFloat),
+/// the target is 64-bit and the target supports scalable vectors.
+/// Returns PreservedAnalyses::all() when nothing changed. Otherwise only the
+/// dominator tree is reported as preserved.
+PreservedAnalyses
+RISCVLoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
+                                 LoopStandardAnalysisResults &AR,
+                                 LPMUpdater &) {
+  if (DisableAll)
+    return PreservedAnalyses::all();
+
+  Function &F = *L.getHeader()->getParent();
+  if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
+                      << " due to its NoImplicitFloat attribute" << "\n");
+    return PreservedAnalyses::all();
+  }
+
+  // Bind by reference: copying a DataLayout is not free, and the transform
+  // object below only borrows it for the duration of this call.
+  const DataLayout &DL = L.getHeader()->getModule()->getDataLayout();
+
+  // Only enabled on RV64 for now.
+  if (DL.getPointerSizeInBits() != 64)
+    return PreservedAnalyses::all();
+
+  // Only enabled when vector extension is present.
+  if (!AR.TTI.supportsScalableVectors())
+    return PreservedAnalyses::all();
+
+  RISCVLoopIdiomRecognize LIR(AR.DT, AR.LI, AR.TLI, AR.TTI, DL);
+  if (!LIR.run(&L))
+    return PreservedAnalyses::all();
+
+  auto PA = PreservedAnalyses::none();
+  PA.preserve<DominatorTreeAnalysis>();
+  return PA;
+}
+
+//===----------------------------------------------------------------------===//
+//
+// Implementation of RISCVLoopIdiomRecognize
+//
+//===----------------------------------------------------------------------===//
+
+/// Records \p L as the current loop and tries the byte-compare idiom on it.
+/// Returns true only if that transform reports a change.
+bool RISCVLoopIdiomRecognize::run(Loop *L) {
+  CurLoop = L;
+
+  // Give up when the pass is switched off or the loop has no preheader.
+  // A missing preheader presumably means the loop could not be put into
+  // canonical form (e.g. because of an indirectbr).
+  if (DisableAll || !L->getLoopPreheader())
+    return false;
+
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F["
+                    << CurLoop->getHeader()->getParent()->getName()
+                    << "] Loop %" << CurLoop->getHeader()->getName()
+                    << "\n");
+
+  return recognizeAndTransformByteCompare();
+}
+
+bool RISCVLoopIdiomRecognize::recognizeAndTransformByteCompare() {
+ if (DisableByteCmp)
+ return false;
+
+ BasicBlock *PH = CurLoop->getLoopPreheader();
+
+ // The preheader should only contain an unconditional branch.
+ if (!PH || &PH->front() != PH->getTerminator())
+ return false;
+
+ using namespace PatternMatch;
+
+ BasicBlock *Header;
+ if (!match(PH->getTerminator(), m_UnconditionalBr(Header)))
+ return false;
+
+ if (Header != CurLoop->getHeader())
+ return false;
+
+ if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
+ return false;
+
+ auto *PN = dyn_cast<PHINode>(&Header->front());
+ if (!PN || PN->getNumIncomingValues() != 2)
+ return false;
+
+ auto LoopBlocks = CurLoop->getBlocks();
+ // The first block in the loop should contain only 4 instructions, e.g.
+ //
+ // while.cond:
+ // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
+ // %inc = add i32 %res.phi, 1
+ // %cmp.not = icmp eq i32 %inc, %n
+ // br i1 %cmp.not, label %while.end, label %while.body
+ //
+ auto CondBBInsts = LoopBlocks[0]->instructionsWithoutDebug();
+ if (std::distance(CondBBInsts.begin(), CondBBInsts.end()) != 4)
+ return false;
+
+ // The second block should contain 7 instructions, e.g.
+ //
+ // while.body:
+ // %idx = zext i32 %inc to i64
+ // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
+ // %load.a = load i8, ptr %idx.a
+ // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
+ // %load.b = load i8, ptr %idx.b
+ // %cmp.not.ld = icmp eq i8 %load.a, %load.b
+ // br i1 %cmp.not.ld, label %while.cond, label %while.end
+ //
+ auto LoopBBInsts = LoopBlocks[1]->instructionsWithoutDebug();
----------------
mshockwave wrote:
Fixed in the new PR(s).
https://github.com/llvm/llvm-project/pull/92441
More information about the llvm-commits
mailing list