[llvm] [LoopInterchange] Support inner-loop simple reductions via UndoSimpleReduction (PR #172970)
Yingying Wang via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 19 00:25:27 PST 2025
https://github.com/buggfg created https://github.com/llvm/llvm-project/pull/172970
Following our [discussion](https://discourse.llvm.org/t/rfc-plan-to-improve-loopinterchange-by-undoing-simple-reductions/89071), I ported GCC’s undo_simple_reduction into LLVM.
**Key changes**
- Implement an UndoSimpleReduction step in LoopInterchange to support simple reductions in the inner loop (a C-level sketch of the rewrite follows this list).
- The feature is gated by the new option `-undo-simple-reduction` and is OFF by default. With the feature off, the pass behaves as before (minimal impact).
- Add a regression test.
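At the source level, the rewrite looks roughly like this. This is only a sketch based on the comment block in the patch and on the new regression test; the pass itself operates on LLVM IR, and the function names here are illustrative.

```c
/* Before: the inner loop carries a scalar simple reduction (as in the new
   regression test). */
void before(int n, double a[][100], double b[][100], double *s) {
  for (int i = 0; i < n; i++) {
    s[i] = 0.0;                        /* constant initial value */
    for (int j = 0; j < n; j++)
      s[i] = s[i] + a[j][i] * b[j][i]; /* reduction carried across the inner loop */
  }
}

/* After UndoSimpleReduction (conceptually): the running value is kept in
   memory on every iteration, so no scalar is carried across the inner loop
   and the usual interchange machinery can proceed. */
void after_undo(int n, double a[][100], double b[][100], double *s) {
  for (int i = 0; i < n; i++)
    for (int j = 0; j < n; j++) {
      /* the constant init is used only on the first iteration */
      double cur = (j != 0) ? s[i] : 0.0;
      s[i] = cur + a[j][i] * b[j][i];  /* store moved next to the update */
    }
}
```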
**Validation & performance**
- Validation: no compile failures or correctness errors were observed on SPEC2006 and SPEC2017 with the new feature enabled.
- Performance (measured with `-da-disable-delinearization-checks` and `-undo-simple-reduction`):
- SPEC2006 410.bwaves on x86 (Intel i9-11900K, Rocket Lake): **+6%**
- SPEC2017 603.bwaves_s on x86 (Intel i9-11900K, Rocket Lake): **+6%**
- SPEC2006 410.bwaves on SpacemiT Key Stone K1: **+59%**
- SPEC2006 410.bwaves on KMH RTL: **+56%**
- SPEC2017 603.bwaves_s on KMH RTL: **+24%**
**Note**
- UndoSimpleReduction only runs when legality and profitability checks indicate the interchange will actually be performed. If interchange is illegal or not profitable, no undo is applied.
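For reference, the new regression test exercises the feature with the following RUN line (`%s` is lit's placeholder for the test file; `-loop-interchange-profitabilities=ignore` is only used to force the interchange in the test):

```
opt < %s -passes="loop(loop-interchange),dce" -undo-simple-reduction -loop-interchange-profitabilities=ignore -S | FileCheck %s
```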
>From 9e354c81a94d2dcaf6c1603dcf31ef0fd453df79 Mon Sep 17 00:00:00 2001
From: buggfg <3171290993 at qq.com>
Date: Fri, 19 Dec 2025 16:11:26 +0800
Subject: [PATCH] Support inner-loop simple reductions via UndoSimpleReduction
Co-Authored-By: ict-ql <168183727+ict-ql at users.noreply.github.com>
Co-Authored-By: Lin Wang <wanglulin at ict.ac.cn>
---
.../lib/Transforms/Scalar/LoopInterchange.cpp | 411 ++++++++++++++----
.../LoopInterchange/simple-reduction.ll | 86 ++++
2 files changed, 414 insertions(+), 83 deletions(-)
create mode 100644 llvm/test/Transforms/LoopInterchange/simple-reduction.ll
diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
index 330b4abb9942f..3da23c7f9ae11 100644
--- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -31,6 +31,7 @@
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
@@ -120,6 +121,12 @@ static cl::list<RuleTy> Profitabilities(
"Ignore profitability, force interchange (does not "
"work with other options)")));
+// Option to enable undoing simple reductions in the inner loop.
+static cl::opt<bool>
+ EnableUndoSimpleReduction("undo-simple-reduction", cl::init(false),
+ cl::Hidden,
+                              cl::desc("Undo simple reductions in the inner loop."));
+
#ifndef NDEBUG
static bool noDuplicateRulesAndIgnore(ArrayRef<RuleTy> Rules) {
SmallSet<RuleTy, 4> Set;
@@ -446,8 +453,8 @@ namespace {
class LoopInterchangeLegality {
public:
LoopInterchangeLegality(Loop *Outer, Loop *Inner, ScalarEvolution *SE,
- OptimizationRemarkEmitter *ORE)
- : OuterLoop(Outer), InnerLoop(Inner), SE(SE), ORE(ORE) {}
+ OptimizationRemarkEmitter *ORE, DominatorTree *DT)
+ : OuterLoop(Outer), InnerLoop(Inner), SE(SE), DT(DT), ORE(ORE) {}
/// Check if the loops can be interchanged.
bool canInterchangeLoops(unsigned InnerLoopId, unsigned OuterLoopId,
@@ -475,9 +482,30 @@ class LoopInterchangeLegality {
return HasNoWrapReductions;
}
+  // Record a simple reduction in the inner loop.
+  struct SimpleReduction {
+    // The reduction PHI itself.
+    PHINode *Re;
+    // So far only a constant initial value is supported.
+    Value *Init;
+    Value *Next;
+    // The LCSSA PHI.
+    PHINode *LcssaPhi;
+    // Only one user is supported for now: the store of the reduction
+    // result into a memory object.
+    StoreInst *LcssaStorer;
+    // The memory location and its element type.
+    Value *MemRef;
+    Type *ElemTy;
+  };
+
+ const ArrayRef<SimpleReduction *> getInnerSimpleReductions() const {
+ return InnerSimpleReductions;
+ }
+
private:
bool tightlyNested(Loop *Outer, Loop *Inner);
- bool containsUnsafeInstructions(BasicBlock *BB);
+ bool containsUnsafeInstructions(BasicBlock *BB, Instruction *Skip);
/// Discover induction and reduction PHIs in the header of \p L. Induction
/// PHIs are added to \p Inductions, reductions are added to
@@ -487,11 +515,16 @@ class LoopInterchangeLegality {
SmallVector<PHINode *, 8> &Inductions,
Loop *InnerLoop);
+ /// Detect simple-reduction PHIs in the inner loop. Add them to
+ /// InnerSimpleReductions.
+ bool findSimpleReduction(Loop *L, PHINode *Phi,
+ SmallVectorImpl<Instruction *> &HasNoWrapInsts);
+
Loop *OuterLoop;
Loop *InnerLoop;
ScalarEvolution *SE;
-
+ DominatorTree *DT;
/// Interface to emit optimization remarks.
OptimizationRemarkEmitter *ORE;
@@ -506,6 +539,9 @@ class LoopInterchangeLegality {
/// like integer addition/multiplication. Those flags must be dropped when
/// interchanging the loops.
SmallVector<Instruction *, 4> HasNoWrapReductions;
+
+ /// Vector of simple reductions of inner loop.
+ SmallVector<SimpleReduction *, 8> InnerSimpleReductions;
};
/// Manages information utilized by the profitability check for cache. The main
@@ -575,6 +611,7 @@ class LoopInterchangeTransform {
/// Interchange OuterLoop and InnerLoop.
bool transform(ArrayRef<Instruction *> DropNoWrapInsts);
+ void undoSimpleReduction();
void restructureLoops(Loop *NewInner, Loop *NewOuter,
BasicBlock *OrigInnerPreHeader,
BasicBlock *OrigOuterPreHeader);
@@ -693,7 +730,7 @@ struct LoopInterchange {
Loop *InnerLoop = LoopList[InnerLoopId];
LLVM_DEBUG(dbgs() << "Processing InnerLoopId = " << InnerLoopId
<< " and OuterLoopId = " << OuterLoopId << "\n");
- LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, ORE);
+ LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, ORE, DT);
if (!LIL.canInterchangeLoops(InnerLoopId, OuterLoopId, DependencyMatrix)) {
LLVM_DEBUG(dbgs() << "Not interchanging loops. Cannot prove legality.\n");
return false;
@@ -734,8 +771,11 @@ struct LoopInterchange {
} // end anonymous namespace
-bool LoopInterchangeLegality::containsUnsafeInstructions(BasicBlock *BB) {
- return any_of(*BB, [](const Instruction &I) {
+bool LoopInterchangeLegality::containsUnsafeInstructions(BasicBlock *BB,
+ Instruction *Skip) {
+ return any_of(*BB, [Skip](const Instruction &I) {
+ if (&I == Skip)
+ return false;
return I.mayHaveSideEffects() || I.mayReadFromMemory();
});
}
@@ -761,17 +801,27 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) {
return false;
LLVM_DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n");
+
+  // The inner-loop simple-reduction pattern stores the LCSSA PHI value in the
+  // outer loop latch. Therefore, when UndoSimpleReduction is enabled, skip
+  // that store in the checks below.
+ Instruction *Skip = nullptr;
+ if (EnableUndoSimpleReduction) {
+ if (InnerSimpleReductions.size() == 1)
+ Skip = InnerSimpleReductions[0]->LcssaStorer;
+ }
+
// We do not have any basic block in between now make sure the outer header
// and outer loop latch doesn't contain any unsafe instructions.
- if (containsUnsafeInstructions(OuterLoopHeader) ||
- containsUnsafeInstructions(OuterLoopLatch))
+ if (containsUnsafeInstructions(OuterLoopHeader, Skip) ||
+ containsUnsafeInstructions(OuterLoopLatch, Skip))
return false;
// Also make sure the inner loop preheader does not contain any unsafe
// instructions. Note that all instructions in the preheader will be moved to
// the outer loop header when interchanging.
if (InnerLoopPreHeader != OuterLoopHeader &&
- containsUnsafeInstructions(InnerLoopPreHeader))
+ containsUnsafeInstructions(InnerLoopPreHeader, Skip))
return false;
BasicBlock *InnerLoopExit = InnerLoop->getExitBlock();
@@ -787,7 +837,7 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) {
// The inner loop exit block does flow to the outer loop latch and not some
// other BBs, now make sure it contains safe instructions, since it will be
// moved into the (new) inner loop after interchange.
- if (containsUnsafeInstructions(InnerLoopExit))
+ if (containsUnsafeInstructions(InnerLoopExit, Skip))
return false;
LLVM_DEBUG(dbgs() << "Loops are perfectly nested\n");
@@ -898,6 +948,77 @@ static Value *followLCSSA(Value *SV) {
return followLCSSA(PHI->getIncomingValue(0));
}
+bool checkReductionKind(Loop *L, PHINode *PHI,
+ SmallVectorImpl<Instruction *> &HasNoWrapInsts) {
+ RecurrenceDescriptor RD;
+ if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD)) {
+ // Detect floating point reduction only when it can be reordered.
+ if (RD.getExactFPMathInst() != nullptr)
+ return false;
+
+ RecurKind RK = RD.getRecurrenceKind();
+ switch (RK) {
+ case RecurKind::Or:
+ case RecurKind::And:
+ case RecurKind::Xor:
+ case RecurKind::SMin:
+ case RecurKind::SMax:
+ case RecurKind::UMin:
+ case RecurKind::UMax:
+ case RecurKind::FAdd:
+ case RecurKind::FMul:
+ case RecurKind::FMin:
+ case RecurKind::FMax:
+ case RecurKind::FMinimum:
+ case RecurKind::FMaximum:
+ case RecurKind::FMinimumNum:
+ case RecurKind::FMaximumNum:
+ case RecurKind::FMulAdd:
+ case RecurKind::AnyOf:
+ return true;
+
+    // Changing the order of integer addition/multiplication may change the
+ // semantics. Consider the following case:
+ //
+ // int A[2][2] = {{ INT_MAX, INT_MAX }, { INT_MIN, INT_MIN }};
+ // int sum = 0;
+ // for (int i = 0; i < 2; i++)
+ // for (int j = 0; j < 2; j++)
+ // sum += A[j][i];
+ //
+ // If the above loops are exchanged, the addition will cause an
+ // overflow. To prevent this, we must drop the nuw/nsw flags from the
+    // addition/multiplication instructions when we actually exchange the
+ // loops.
+ case RecurKind::Add:
+ case RecurKind::Mul: {
+ unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
+ SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);
+
+ // Bail out when we fail to collect reduction instructions chain.
+ if (Ops.empty())
+ return false;
+
+ for (Instruction *I : Ops) {
+ assert(I->getOpcode() == OpCode &&
+ "Expected the instruction to be the reduction operation");
+ (void)OpCode;
+
+ // If the instruction has nuw/nsw flags, we must drop them when the
+ // transformation is actually performed.
+ if (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())
+ HasNoWrapInsts.push_back(I);
+ }
+ return true;
+ }
+
+ default:
+ return false;
+ }
+ } else
+ return false;
+}
+
// Check V's users to see if it is involved in a reduction in L.
static PHINode *
findInnerReductionPhi(Loop *L, Value *V,
@@ -910,72 +1031,12 @@ findInnerReductionPhi(Loop *L, Value *V,
if (PHINode *PHI = dyn_cast<PHINode>(User)) {
if (PHI->getNumIncomingValues() == 1)
continue;
- RecurrenceDescriptor RD;
- if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD)) {
- // Detect floating point reduction only when it can be reordered.
- if (RD.getExactFPMathInst() != nullptr)
- return nullptr;
-
- RecurKind RK = RD.getRecurrenceKind();
- switch (RK) {
- case RecurKind::Or:
- case RecurKind::And:
- case RecurKind::Xor:
- case RecurKind::SMin:
- case RecurKind::SMax:
- case RecurKind::UMin:
- case RecurKind::UMax:
- case RecurKind::FAdd:
- case RecurKind::FMul:
- case RecurKind::FMin:
- case RecurKind::FMax:
- case RecurKind::FMinimum:
- case RecurKind::FMaximum:
- case RecurKind::FMinimumNum:
- case RecurKind::FMaximumNum:
- case RecurKind::FMulAdd:
- case RecurKind::AnyOf:
- return PHI;
-
- // Change the order of integer addition/multiplication may change the
- // semantics. Consider the following case:
- //
- // int A[2][2] = {{ INT_MAX, INT_MAX }, { INT_MIN, INT_MIN }};
- // int sum = 0;
- // for (int i = 0; i < 2; i++)
- // for (int j = 0; j < 2; j++)
- // sum += A[j][i];
- //
- // If the above loops are exchanged, the addition will cause an
- // overflow. To prevent this, we must drop the nuw/nsw flags from the
- // addition/multiplication instructions when we actually exchanges the
- // loops.
- case RecurKind::Add:
- case RecurKind::Mul: {
- unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
- SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);
-
- // Bail out when we fail to collect reduction instructions chain.
- if (Ops.empty())
- return nullptr;
-
- for (Instruction *I : Ops) {
- assert(I->getOpcode() == OpCode &&
- "Expected the instruction to be the reduction operation");
- (void)OpCode;
-
- // If the instruction has nuw/nsw flags, we must drop them when the
- // transformation is actually performed.
- if (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())
- HasNoWrapInsts.push_back(I);
- }
- return PHI;
- }
- default:
- return nullptr;
- }
- }
+      if (checkReductionKind(L, PHI, HasNoWrapInsts))
+ return PHI;
+ else
+ return nullptr;
+
return nullptr;
}
}
@@ -983,6 +1044,116 @@ findInnerReductionPhi(Loop *L, Value *V,
return nullptr;
}
+// Detect and record a simple reduction in the inner loop.
+//
+// innerloop:
+// Re = phi<0.0, Next>
+// ReUser = Re op ...
+// ...
+// Next = ReUser op ...
+// OuterLoopLatch:
+// Lcssa = phi<Next> ; lcssa phi
+// store Lcssa, MemRef ; LcssaStorer
+//
+bool LoopInterchangeLegality::findSimpleReduction(
+ Loop *L, PHINode *Phi, SmallVectorImpl<Instruction *> &HasNoWrapInsts) {
+
+  // Only support undoing a simple reduction when the inner loop of the loop
+  // nest to be interchanged is the innermost loop.
+ if (!L->isInnermost())
+ return false;
+
+ if (Phi->getNumIncomingValues() != 2)
+ return false;
+
+ Value *Init = Phi->getIncomingValueForBlock(L->getLoopPreheader());
+ Value *Next = Phi->getIncomingValueForBlock(L->getLoopLatch());
+
+  // So far only a constant initial value is supported.
+ auto *ConstInit = dyn_cast<Constant>(Init);
+ if (!ConstInit)
+ return false;
+
+ // The reduction result must live in the inner loop.
+ if (Instruction *I = dyn_cast<Instruction>(Next)) {
+ BasicBlock *BB = I->getParent();
+ if (!L->contains(BB))
+ return false;
+ }
+
+ // The reduction should have only one user.
+ if (!Phi->hasOneUser())
+ return false;
+ Instruction *ReUser = dyn_cast<Instruction>(Phi->getUniqueUndroppableUser());
+ if (!ReUser || !L->contains(ReUser->getParent()))
+ return false;
+
+ // Check the reduction operation.
+ if (!ReUser->isAssociative() || !ReUser->isBinaryOp() ||
+ (ReUser->getOpcode() == Instruction::Sub &&
+ ReUser->getOperand(0) == Phi) ||
+ (ReUser->getOpcode() == Instruction::FSub &&
+ ReUser->getOperand(0) == Phi))
+ return false;
+
+ // Check the reduction kind.
+  if (ReUser != Next && !checkReductionKind(L, Phi, HasNoWrapInsts))
+ return false;
+
+  // Find the LCSSA PHI in the outer loop's latch.
+ if (!L->getExitingBlock())
+ return false;
+ BranchInst *BI = dyn_cast<BranchInst>(L->getExitingBlock()->getTerminator());
+ if (!BI)
+ return false;
+ BasicBlock *ExitBlock =
+ BI->getSuccessor(L->contains(BI->getSuccessor(0)) ? 1 : 0);
+ if (!ExitBlock)
+ return false;
+
+  PHINode *Lcssa = nullptr;
+ for (auto *U : Next->users()) {
+ if (auto *P = dyn_cast<PHINode>(U)) {
+ if (P == Phi)
+ continue;
+
+      if (!Lcssa && P->getParent() == ExitBlock &&
+ P->getIncomingValueForBlock(L->getLoopLatch()) == Next)
+ Lcssa = P;
+ } else
+ return false;
+ }
+ if (!Lcssa || !Lcssa->hasOneUser())
+ return false;
+
+ StoreInst *LcssaStorer =
+ dyn_cast<StoreInst>(Lcssa->getUniqueUndroppableUser());
+ if (!LcssaStorer)
+ return false;
+
+ Value *MemRef = LcssaStorer->getOperand(1);
+ Type *ElemTy = LcssaStorer->getOperand(0)->getType();
+
+  // LcssaStorer stores the reduction result in the exit block.
+  // undoSimpleReduction() will move it into the inner loop, so the memory
+  // reference must dominate the target block; otherwise the move is unsafe.
+ if (!DT->dominates(dyn_cast<Instruction>(MemRef), ExitBlock))
+ return false;
+
+  // Found a simple reduction in the inner loop.
+ SimpleReduction *SR = new SimpleReduction;
+ SR->Re = Phi;
+ SR->Init = Init;
+ SR->Next = Next;
+ SR->LcssaPhi = Lcssa;
+ SR->LcssaStorer = LcssaStorer;
+ SR->MemRef = MemRef;
+ SR->ElemTy = ElemTy;
+
+  InnerSimpleReductions.push_back(SR);
+ return true;
+}
+
bool LoopInterchangeLegality::findInductionAndReductions(
Loop *L, SmallVector<PHINode *, 8> &Inductions, Loop *InnerLoop) {
if (!L->getLoopLatch() || !L->getLoopPredecessor())
@@ -995,11 +1166,14 @@ bool LoopInterchangeLegality::findInductionAndReductions(
// PHIs in inner loops need to be part of a reduction in the outer loop,
// discovered when checking the PHIs of the outer loop earlier.
if (!InnerLoop) {
- if (!OuterInnerReductions.count(&PHI)) {
- LLVM_DEBUG(dbgs() << "Inner loop PHI is not part of reductions "
- "across the outer loop.\n");
+ if (OuterInnerReductions.count(&PHI)) {
+ LLVM_DEBUG(dbgs() << "Found a reduction across the outer loop.\n");
+ } else if (EnableUndoSimpleReduction &&
+ findSimpleReduction(L, &PHI, HasNoWrapReductions)) {
+ LLVM_DEBUG(dbgs() << "Found a simple reduction in the inner loop: \n"
+ << PHI << '\n');
+ } else
return false;
- }
} else {
assert(PHI.getNumIncomingValues() == 2 &&
"Phis in loop header should have exactly 2 incoming values");
@@ -1020,6 +1194,10 @@ bool LoopInterchangeLegality::findInductionAndReductions(
}
}
}
+
+  // For now we support at most one simple reduction in the inner loop.
+ if (InnerSimpleReductions.size() > 1)
+ return false;
return true;
}
@@ -1115,12 +1293,15 @@ bool LoopInterchangeLegality::findInductions(
// the we are only interested in the final value after the loop).
static bool
areInnerLoopExitPHIsSupported(Loop *InnerL, Loop *OuterL,
- SmallPtrSetImpl<PHINode *> &Reductions) {
+ SmallPtrSetImpl<PHINode *> &Reductions,
+ PHINode *LcssaSimpleRed) {
BasicBlock *InnerExit = OuterL->getUniqueExitBlock();
for (PHINode &PHI : InnerExit->phis()) {
// Reduction lcssa phi will have only 1 incoming block that from loop latch.
if (PHI.getNumIncomingValues() > 1)
return false;
+ if (&PHI == LcssaSimpleRed)
+ return true;
if (any_of(PHI.users(), [&Reductions, OuterL](User *U) {
PHINode *PN = dyn_cast<PHINode>(U);
return !PN ||
@@ -1270,8 +1451,16 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId,
return false;
}
- if (!areInnerLoopExitPHIsSupported(OuterLoop, InnerLoop,
- OuterInnerReductions)) {
+  // The LCSSA PHI for the simple reduction has already been checked in
+  // findSimpleReduction(); its only user is a store instruction.
+ PHINode *LcssaSimpleRed = nullptr;
+ if (EnableUndoSimpleReduction) {
+ if (InnerSimpleReductions.size() == 1)
+ LcssaSimpleRed = InnerSimpleReductions[0]->LcssaPhi;
+ }
+
+ if (!areInnerLoopExitPHIsSupported(OuterLoop, InnerLoop, OuterInnerReductions,
+ LcssaSimpleRed)) {
LLVM_DEBUG(dbgs() << "Found unsupported PHI nodes in inner loop exit.\n");
ORE->emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedExitPHI",
@@ -1633,10 +1822,66 @@ void LoopInterchangeTransform::restructureLoops(
SE->forgetLoop(NewOuter);
}
+/*
+  Users can write, and optimizers can generate, simple reductions in the inner
+  loop. In order to make the interchange valid, we undo the reduction by
+  moving the initialization and the store into the inner loop. So far we only
+  handle cases where the reduction variable is initialized to a constant.
+  For example, the code below:
+
+  loop:
+    re = phi<0.0, next>
+    next = re op ...
+  reduc_sum = phi<next>       // lcssa phi
+  MEM_REF[idx] = reduc_sum    // LcssaStorer
+
+  is transformed into:
+
+  loop:
+    tmp = MEM_REF[idx];
+    new_var = !first_iteration ? tmp : 0.0;
+    next = new_var op ...
+    MEM_REF[idx] = next;      // after moving
+
+  This way the constant initial value is only used in the first iteration.
+*/
+void LoopInterchangeTransform::undoSimpleReduction() {
+
+ auto &InnerSimpleReductions = LIL.getInnerSimpleReductions();
+ LoopInterchangeLegality::SimpleReduction *SR = InnerSimpleReductions[0];
+ BasicBlock *InnerLoopHeader = InnerLoop->getHeader();
+ IRBuilder<> Builder(&*(InnerLoopHeader->getFirstNonPHIIt()));
+
+  // The reduction is initialized from a constant value, so when undoing the
+  // reduction we need to insert a load from the memory object at the top of
+  // the inner loop header.
+ Instruction *LoadMem = Builder.CreateLoad(SR->ElemTy, SR->MemRef);
+
+  // Check whether this is NOT the first iteration (the IV has advanced past
+  // its initial value); on the first iteration we use the constant init.
+  auto &InductionPHIs = LIL.getInnerLoopInductions();
+  PHINode *IV = InductionPHIs[0];
+  Value *IVInit = IV->getIncomingValueForBlock(InnerLoop->getLoopPreheader());
+  Value *NotFirst = Builder.CreateICmpNE(IV, IVInit, "not.first.iter");
+
+  Value *NewVar = Builder.CreateSelect(NotFirst, LoadMem, SR->Init, "new.var");
+
+ // Replace all uses of reduction var with new variable.
+ SR->Re->replaceAllUsesWith(NewVar);
+
+  // Move the store into the inner loop, right after the definition of Next.
+  SR->LcssaStorer->setOperand(0, SR->Next);
+  SR->LcssaStorer->moveAfter(cast<Instruction>(SR->Next));
+}
+
bool LoopInterchangeTransform::transform(
ArrayRef<Instruction *> DropNoWrapInsts) {
bool Transformed = false;
+ auto &InnerSimpleReductions = LIL.getInnerSimpleReductions();
+ if (EnableUndoSimpleReduction && InnerSimpleReductions.size() == 1)
+ undoSimpleReduction();
+
if (InnerLoop->getSubLoops().empty()) {
BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
LLVM_DEBUG(dbgs() << "Splitting the inner loop latch\n");
diff --git a/llvm/test/Transforms/LoopInterchange/simple-reduction.ll b/llvm/test/Transforms/LoopInterchange/simple-reduction.ll
new file mode 100644
index 0000000000000..9a4393f827a36
--- /dev/null
+++ b/llvm/test/Transforms/LoopInterchange/simple-reduction.ll
@@ -0,0 +1,86 @@
+; Test that a simple reduction in the inner loop is supported by undoing the
+; reduction before the interchange.
+; RUN: opt < %s -passes="loop(loop-interchange),dce" -undo-simple-reduction -loop-interchange-profitabilities=ignore -S | FileCheck %s
+
+; for (int i = 0; i < n; i++) {
+; s[i] = 0;
+; for (int j = 0; j < n; j++)
+; s[i] = s[i] + a[j][i] * b[j][i];
+; }
+
+define void @func(ptr noalias noundef readonly captures(none) %a, ptr noalias noundef readonly captures(none) %b, ptr noalias noundef writeonly captures(none) %s, i64 noundef %n) {
+; CHECK-LABEL: define void @func(ptr noalias noundef readonly captures(none) %a, ptr noalias noundef readonly captures(none) %b, ptr noalias noundef writeonly captures(none) %s, i64 noundef %n) {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i64 [[N:%.*]], 0
+; CHECK-NEXT: br i1 [[CMP]], label [[INNERLOOP_PREHEADER:%.*]], label [[EXIT:%.*]]
+; CHECK: outerloop_header.preheader:
+; CHECK-NEXT: br label [[OUTERLOOP_HEADER:%.*]]
+; CHECK: outerloop_header:
+; CHECK-NEXT: [[INDEX_I:%.*]] = phi i64 [ [[I_NEXT:%.*]], [[OUTERLOOP_LATCH:%.*]] ], [ 0, [[OUTERLOOPHEADER_PREHEADER:%.*]] ]
+; CHECK-NEXT: [[ADDR_S:%.*]] = getelementptr inbounds nuw double, ptr %s, i64 [[INDEX_I]]
+; CHECK-NEXT: [[ADDR_A:%.*]] = getelementptr inbounds nuw [100 x double], ptr %a, i64 0, i64 [[INDEX_I]]
+; CHECK-NEXT: [[ADDR_B:%.*]] = getelementptr inbounds nuw [100 x double], ptr %b, i64 0, i64 [[INDEX_I]]
+; CHECK-NEXT: br label [[INNERLOOP_SPLIT1:%.*]]
+; CHECK: innerloop.preheader:
+; CHECK-NEXT: br label [[INNERLOOP:%.*]]
+; CHECK: innerloop:
+; CHECK-NEXT: [[INDEX_J:%.*]] = phi i64 [ [[J_NEXT:%.*]], [[INNERLOOP_SPLIT:%.*]] ], [ 0, [[INNERLOOP_PREHEADER:%.*]] ]
+; CHECK-NEXT: br label [[OUTERLOOPHEADER_PREHEADER:%.*]]
+; CHECK: innerloop.split1:
+; CHECK-NEXT: [[S:%.*]] = load double, ptr [[ADDR_S]], align 8
+; CHECK-NEXT: [[FIRSTITER:%.*]] = icmp ne i64 [[INDEX_J]], 0
+; CHECK-NEXT: [[NEW_VAR:%.*]] = select i1 [[FIRSTITER]], double [[S]], double 0.000000e+00
+; CHECK-NEXT: [[ADDR_A_J_I:%.*]] = getelementptr inbounds nuw [100 x double], ptr [[ADDR_A]], i64 [[INDEX_J]]
+; CHECK-NEXT: [[A_J_I:%.*]] = load double, ptr [[ADDR_A_J_I]], align 8
+; CHECK-NEXT: [[ADDR_B_J_I:%.*]] = getelementptr inbounds nuw [100 x double], ptr [[ADDR_B]], i64 [[INDEX_J]]
+; CHECK-NEXT: [[B_J_I:%.*]] = load double, ptr [[ADDR_B_J_I]], align 8
+; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[B_J_I]], [[A_J_I]]
+; CHECK-NEXT: [[ADD:%.*]] = fadd fast double [[MUL]], [[NEW_VAR]]
+; CHECK-NEXT: store double [[ADD]], ptr [[ADDR_S]], align 8
+; CHECK-NEXT: br label [[OUTERLOOP_LATCH:%.*]]
+; CHECK: innerloop.split:
+; CHECK-NEXT: [[J_NEXT:%.*]] = add nuw nsw i64 [[INDEX_J]], 1
+; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i64 [[J_NEXT]], [[N]]
+; CHECK-NEXT: br i1 [[CMP1]], label [[EXIT_LOOPEXIT:%.*]], label [[INNERLOOP]]
+; CHECK: outerloop_latch:
+; CHECK-NEXT: [[I_NEXT]] = add nuw nsw i64 [[INDEX_I]], 1
+; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i64 [[I_NEXT]], [[N]]
+; CHECK-NEXT: br i1 [[CMP2]], label [[INNERLOOP_SPLIT:%.*]], label [[OUTERLOOP_HEADER]]
+; CHECK: exit.loopexit:
+; CHECK-NEXT: br label [[EXIT:%.*]]
+; CHECK: exit:
+; CHECK-NEXT: ret void
+;
+entry:
+ %cmp = icmp sgt i64 %n, 0
+ br i1 %cmp, label %outerloop_header, label %exit
+
+outerloop_header:
+ %index_i = phi i64 [ 0, %entry ], [ %index_i.next, %outerloop_latch ]
+ %addr_s = getelementptr inbounds nuw double, ptr %s, i64 %index_i
+ %invariant.gep.us = getelementptr inbounds nuw [100 x double], ptr %a, i64 0, i64 %index_i
+ %invariant.gep32.us = getelementptr inbounds nuw [100 x double], ptr %b, i64 0, i64 %index_i
+ br label %innerloop
+
+innerloop:
+ %index_j = phi i64 [ 0, %outerloop_header ], [ %index_j.next, %innerloop ]
+ %reduction = phi double [ 0.000000e+00, %outerloop_header ], [ %add, %innerloop ]
+ %addr_a_j_i = getelementptr inbounds nuw [100 x double], ptr %invariant.gep.us, i64 %index_j
+ %0 = load double, ptr %addr_a_j_i, align 8
+ %addr_b_j_i = getelementptr inbounds nuw [100 x double], ptr %invariant.gep32.us, i64 %index_j
+ %1 = load double, ptr %addr_b_j_i, align 8
+ %mul = fmul fast double %1, %0
+ %add = fadd fast double %mul, %reduction
+ %index_j.next = add nuw nsw i64 %index_j, 1
+ %cond1 = icmp eq i64 %index_j.next, %n
+ br i1 %cond1, label %outerloop_latch, label %innerloop
+
+outerloop_latch:
+ %lcssa = phi double [ %add, %innerloop ]
+ store double %lcssa, ptr %addr_s, align 8
+ %index_i.next = add nuw nsw i64 %index_i, 1
+ %cond2 = icmp eq i64 %index_i.next, %n
+ br i1 %cond2, label %exit, label %outerloop_header
+
+exit:
+ ret void
+}