[llvm] [RISCV] Improve loop by extract reduction instruction (PR #179215)

via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 8 23:07:05 PST 2026


================
@@ -0,0 +1,213 @@
+//===-------- LoopReduceMotion.cpp - Loop Reduce Motion Optimization ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This pass is designed to sink `ReduceCall` operations out of loops to reduce
+// the number of instructions within the loop body.
+//
+// Below is the target pattern to be matched and the resulting pattern
+// after the transformation.
+//
+// before                    | after
+// ------                    | ------
+// loop:                     | loop:
+//   ...                     |   ...
+//   vc = vecbin va, vb      |   vc = vecbin va, vb
+//   d = reduce_add vc       |   vsum = vadd vsum, vc
+//   sum = add sum, d        |   ...
+//   ...                     |   ...
+// exit:                     | exit:
+//   value = sum             |   d = reduce_add vsum
+//   ...                     |   value = d
+//   ...                     |   ...
+//   ret                     |   ret
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Vectorize/LoopReduceMotion.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Plugins/PassPlugin.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+
+#define DEBUG_TYPE "loop-reduce-motion"
+
+using namespace llvm;
+
+/// Legacy pass-manager wrapper around LoopReduceMotionPass. It owns one
+/// implementation object and forwards each loop to its matchAndTransform().
+class LoopReduceMotion : public FunctionPass {
+  LoopReduceMotionPass Impl;
+
+public:
+  static char ID;
+
+  LoopReduceMotion() : FunctionPass(ID) {}
+
+  StringRef getPassName() const override { return "Loop Reduce Motion Pass"; }
+
+  bool runOnFunction(Function &F) override;
+
+  /// Declares the analyses that runOnFunction() reads.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    // Deliberately no setPreservesCFG(): matchAndTransform() may create a
+    // loop preheader (InsertPreheaderForLoop) and splits the loop exit edge
+    // (SplitEdge). Both add basic blocks, so the CFG is not preserved.
+  }
+};
+
+char LoopReduceMotion::ID = 0;
+
+/// New pass-manager entry point. Runs matchAndTransform() on each loop
+/// obtained by iterating the function's LoopInfo. Reports every analysis as
+/// preserved when nothing was rewritten and none as preserved otherwise.
+///
+/// NOTE(review): iterating LoopInfo directly yields top-level loops only, so
+/// nested loops are never offered to matchAndTransform() -- confirm that this
+/// is intended.
+PreservedAnalyses LoopReduceMotionPass::run(Function &F,
+                                            FunctionAnalysisManager &FAM) {
+  LoopInfo &Loops = FAM.getResult<LoopAnalysis>(F);
+  DominatorTree &DomTree = FAM.getResult<DominatorTreeAnalysis>(F);
+
+  bool MadeChange = false;
+  for (Loop *Candidate : Loops)
+    MadeChange |= matchAndTransform(*Candidate, DomTree, Loops);
+
+  return MadeChange ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
+
+/// Legacy pass-manager entry point. Runs the shared implementation on each
+/// loop obtained by iterating LoopInfo and returns true if any loop was
+/// rewritten.
+bool LoopReduceMotion::runOnFunction(Function &F) {
+  // Do nothing when the function must be skipped or when no TargetPassConfig
+  // is available.
+  if (skipFunction(F) || !getAnalysisIfAvailable<TargetPassConfig>())
+    return false;
+
+  LLVM_DEBUG(dbgs() << "*** " << getPassName() << ": " << F.getName() << "\n");
+
+  DominatorTree &DomTree =
+      getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  LoopInfo &Loops = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+  bool MadeChange = false;
+  for (Loop *Candidate : Loops)
+    MadeChange |= Impl.matchAndTransform(*Candidate, DomTree, Loops);
+
+  return MadeChange;
+}
+
+bool LoopReduceMotionPass::matchAndTransform(Loop &L, DominatorTree &DT,
+                                             LoopInfo &LI) {
+  BasicBlock *Header = L.getHeader();
+  BasicBlock *Latch = L.getLoopLatch();
+  BasicBlock *ExitBlock = L.getExitBlock();
+  if (!Header || !Latch || !ExitBlock) {
+    LLVM_DEBUG(dbgs() << "LRM: Skipping loop " << Header->getName()
+                      << " because it is not a valid loop.\n");
+    return false;
+  }
+  BasicBlock *Preheader = L.getLoopPreheader();
+  if (!Preheader) {
+    Preheader = InsertPreheaderForLoop(&L, &DT, &LI, nullptr, false);
+    if (!Preheader) {
+      LLVM_DEBUG(dbgs() << "LRM: Failed to create a preheader for loop "
+                        << Header->getName() << ".\n");
+      return false;
+    }
+  }
+  for (PHINode &PN : Header->phis()) {
+    if (!PN.getType()->isIntegerTy())
+      continue;
+
+    RecurrenceDescriptor RecDesc;
+    if (!RecurrenceDescriptor::isReductionPHI(&PN, &L, RecDesc))
+      continue;
+
+    if (RecDesc.getRecurrenceKind() != RecurKind::Add)
+      continue;
+
+    Value *RecurrenceValueFromPHI = PN.getIncomingValueForBlock(Latch);
+    Instruction *RecurrenceInst = dyn_cast<Instruction>(RecurrenceValueFromPHI);
+    if (!RecurrenceInst || RecurrenceInst->getNumOperands() != 2)
+      continue;
+
+    Value *RecurrenceValue = RecurrenceInst->getOperand(0) == &PN
+                                 ? RecurrenceInst->getOperand(1)
+                                 : RecurrenceInst->getOperand(0);
+
+    CallInst *ReduceCall = dyn_cast<CallInst>(RecurrenceValue);
+    if (!ReduceCall)
+      continue;
+    Function *CalledFunc = ReduceCall->getCalledFunction();
+
+    if (!CalledFunc || !CalledFunc->isIntrinsic() ||
+        !(CalledFunc->getIntrinsicID() == Intrinsic::vector_reduce_add))
+      continue;
+
+    Value *ReduceOperand = ReduceCall->getArgOperand(0);
+    Instruction *VecBin = dyn_cast<Instruction>(ReduceOperand);
+    if (!VecBin || (VecBin->getOpcode() != Instruction::Sub &&
+                    VecBin->getOpcode() != Instruction::Add))
+      continue;
+    // pattern match success
+    LLVM_DEBUG(dbgs() << "Found pattern to optimize in loop "
+                      << Header->getName() << "!\n");
+
+    VectorType *VecTy = cast<VectorType>(VecBin->getType());
+    Value *VecZero = ConstantInt::get(VecTy, 0);
+
+    // build new Vector Add to replace Scalar Add
+    IRBuilder<> HeaderBuilder(Header, Header->getFirstNonPHIIt());
+    PHINode *VecSumPhi = HeaderBuilder.CreatePHI(VecTy, 2, "vec.sum.phi");
+    VecSumPhi->addIncoming(VecZero, Preheader);
+    IRBuilder<> BodyBuilder(RecurrenceInst);
+    Value *NewVecAdd = BodyBuilder.CreateAdd(VecSumPhi, VecBin, "vec.sum.next");
+    VecSumPhi->addIncoming(NewVecAdd, Latch);
+
+    // build landingPad for reduce add out of loop
+    BasicBlock *ExitingBlock =
+        Latch->getTerminator()->getSuccessor(0) == Header ? Latch : Header;
+    if (!L.isLoopExiting(ExitingBlock)) {
+      ExitingBlock = Header;
+    }
+    BasicBlock *LandingPad = SplitEdge(ExitingBlock, ExitBlock, &DT, &LI);
+    LandingPad->setName("loop.exit.landing");
+    IRBuilder<> LandingPadBuilder(LandingPad->getTerminator());
+    Value *ScalarTotalSum = LandingPadBuilder.CreateCall(
+        ReduceCall->getCalledFunction(), NewVecAdd, "scalar.total.sum");
+    Value *PreheaderValue = PN.getIncomingValueForBlock(Preheader);
+    Value *LastAdd =
+        PreheaderValue
+            ? LandingPadBuilder.CreateAdd(PreheaderValue, ScalarTotalSum)
+            : ScalarTotalSum;
+
+    // replace use of phi and erase use empty value
+    if (!PN.use_empty())
+      PN.replaceAllUsesWith(PoisonValue::get(PN.getType()));
+    if (PN.use_empty())
+      PN.eraseFromParent();
+
+    Instruction *FinalNode = dyn_cast<Instruction>(LastAdd);
+    if (!FinalNode)
+      return false;
+    RecurrenceInst->replaceAllUsesWith(FinalNode);
+
+    if (RecurrenceInst->use_empty())
+      RecurrenceInst->eraseFromParent();
+    if (ReduceCall->use_empty())
+      ReduceCall->eraseFromParent();
----------------
Anjian-Wen wrote:

fixed!

https://github.com/llvm/llvm-project/pull/179215


More information about the llvm-commits mailing list