[llvm] [LoopInterchange] Support inner-loop simple reductions via UndoSimpleReduction (PR #172970)

Yingying Wang via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 19 00:25:27 PST 2025


https://github.com/buggfg created https://github.com/llvm/llvm-project/pull/172970

Following our [discussion](https://discourse.llvm.org/t/rfc-plan-to-improve-loopinterchange-by-undoing-simple-reductions/89071), I ported GCC's `undo_simple_reduction` into LLVM's LoopInterchange pass.

**Key changes**

- Implement an UndoSimpleReduction step in LoopInterchange to support simple reductions in the inner loop.
- The feature is behind the option `-undo-simple-reduction`, which is OFF by default. With the feature off, the pass behaves as before (minimal impact).
- Add a regression test.

**Validation & performance**

- No compilation failures or output-validation errors were observed on SPEC2006 and SPEC2017 with the new feature enabled.
- With options: `-da-disable-delinearization-checks` and `-undo-simple-reduction`
  - SPEC2006 410.bwaves on x86 (Intel i9-11900K, Rocket Lake): **+6%**
  - SPEC2017 603.bwaves_s on x86 (Intel i9-11900K, Rocket Lake): **+6%**
  - SPEC2006 410.bwaves on SpacemiT Key Stone K1: **+59%**
  - SPEC2006 410.bwaves on KMH RTL: **+56%**
  - SPEC2017 603.bwaves_s on KMH RTL: **+24%**

**Note**

- UndoSimpleReduction only runs when the legality and profitability checks indicate that the interchange will actually be performed. If the interchange is illegal or not profitable, no undo is applied.

>From 9e354c81a94d2dcaf6c1603dcf31ef0fd453df79 Mon Sep 17 00:00:00 2001
From: buggfg <3171290993 at qq.com>
Date: Fri, 19 Dec 2025 16:11:26 +0800
Subject: [PATCH] Support inner-loop simple reductions via UndoSimpleReduction

Co-Authored-By: ict-ql <168183727+ict-ql at users.noreply.github.com>
Co-Authored-By: Lin Wang <wanglulin at ict.ac.cn>
---
 .../lib/Transforms/Scalar/LoopInterchange.cpp | 411 ++++++++++++++----
 .../LoopInterchange/simple-reduction.ll       |  86 ++++
 2 files changed, 414 insertions(+), 83 deletions(-)
 create mode 100644 llvm/test/Transforms/LoopInterchange/simple-reduction.ll

diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
index 330b4abb9942f..3da23c7f9ae11 100644
--- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -31,6 +31,7 @@
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -120,6 +121,12 @@ static cl::list<RuleTy> Profitabilities(
                           "Ignore profitability, force interchange (does not "
                           "work with other options)")));
 
+// Opt-in switch: undo simple reductions of the inner loop so that such loop
+// nests can be interchanged. Disabled by default.
+static cl::opt<bool> EnableUndoSimpleReduction(
+    "undo-simple-reduction", cl::init(false), cl::Hidden,
+    cl::desc("Support for simple reduction of inner loop."));
+
 #ifndef NDEBUG
 static bool noDuplicateRulesAndIgnore(ArrayRef<RuleTy> Rules) {
   SmallSet<RuleTy, 4> Set;
@@ -446,8 +453,8 @@ namespace {
 class LoopInterchangeLegality {
 public:
   LoopInterchangeLegality(Loop *Outer, Loop *Inner, ScalarEvolution *SE,
-                          OptimizationRemarkEmitter *ORE)
-      : OuterLoop(Outer), InnerLoop(Inner), SE(SE), ORE(ORE) {}
+                          OptimizationRemarkEmitter *ORE, DominatorTree *DT)
+      : OuterLoop(Outer), InnerLoop(Inner), SE(SE), DT(DT), ORE(ORE) {}
 
   /// Check if the loops can be interchanged.
   bool canInterchangeLoops(unsigned InnerLoopId, unsigned OuterLoopId,
@@ -475,9 +482,30 @@ class LoopInterchangeLegality {
     return HasNoWrapReductions;
   }
 
+  // Record the simple reduction in the inner loop.
+  struct SimpleReduction {
+    // The reduction itself;
+    PHINode *Re;
+    // So far only supports constant initial value.
+    Value *Init;
+    Value *Next;
+    // The Lcssa PHI
+    PHINode *LcssaPhi;
+    // Only supports one user for now
+    // Store reduction result into memory object
+    StoreInst *LcssaStorer;
+    // The memory Location
+    Value *MemRef;
+    Type *ElemTy;
+  };
+
+  const ArrayRef<SimpleReduction *> getInnerSimpleReductions() const {
+    return InnerSimpleReductions;
+  }
+
 private:
   bool tightlyNested(Loop *Outer, Loop *Inner);
-  bool containsUnsafeInstructions(BasicBlock *BB);
+  bool containsUnsafeInstructions(BasicBlock *BB, Instruction *Skip);
 
   /// Discover induction and reduction PHIs in the header of \p L. Induction
   /// PHIs are added to \p Inductions, reductions are added to
@@ -487,11 +515,16 @@ class LoopInterchangeLegality {
                                   SmallVector<PHINode *, 8> &Inductions,
                                   Loop *InnerLoop);
 
+  /// Detect simple-reduction PHIs in the inner loop. Add them to
+  /// InnerSimpleReductions.
+  bool findSimpleReduction(Loop *L, PHINode *Phi,
+                           SmallVectorImpl<Instruction *> &HasNoWrapInsts);
+
   Loop *OuterLoop;
   Loop *InnerLoop;
 
   ScalarEvolution *SE;
-
+  DominatorTree *DT;
   /// Interface to emit optimization remarks.
   OptimizationRemarkEmitter *ORE;
 
@@ -506,6 +539,9 @@ class LoopInterchangeLegality {
   /// like integer addition/multiplication. Those flags must be dropped when
   /// interchanging the loops.
   SmallVector<Instruction *, 4> HasNoWrapReductions;
+
+  /// Vector of simple reductions of inner loop.
+  SmallVector<SimpleReduction *, 8> InnerSimpleReductions;
 };
 
 /// Manages information utilized by the profitability check for cache. The main
@@ -575,6 +611,7 @@ class LoopInterchangeTransform {
 
   /// Interchange OuterLoop and InnerLoop.
   bool transform(ArrayRef<Instruction *> DropNoWrapInsts);
+  void undoSimpleReduction();
   void restructureLoops(Loop *NewInner, Loop *NewOuter,
                         BasicBlock *OrigInnerPreHeader,
                         BasicBlock *OrigOuterPreHeader);
@@ -693,7 +730,7 @@ struct LoopInterchange {
     Loop *InnerLoop = LoopList[InnerLoopId];
     LLVM_DEBUG(dbgs() << "Processing InnerLoopId = " << InnerLoopId
                       << " and OuterLoopId = " << OuterLoopId << "\n");
-    LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, ORE);
+    LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, ORE, DT);
     if (!LIL.canInterchangeLoops(InnerLoopId, OuterLoopId, DependencyMatrix)) {
       LLVM_DEBUG(dbgs() << "Not interchanging loops. Cannot prove legality.\n");
       return false;
@@ -734,8 +771,11 @@ struct LoopInterchange {
 
 } // end anonymous namespace
 
-bool LoopInterchangeLegality::containsUnsafeInstructions(BasicBlock *BB) {
-  return any_of(*BB, [](const Instruction &I) {
+bool LoopInterchangeLegality::containsUnsafeInstructions(
+    BasicBlock *BB, Instruction *Skip) { // Skip may be null.
+  return any_of(*BB, [Skip](const Instruction &I) {
+    if (&I == Skip) // Exempted by the caller; never counts as unsafe.
+      return false;
     return I.mayHaveSideEffects() || I.mayReadFromMemory();
   });
 }
@@ -761,17 +801,27 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) {
       return false;
 
   LLVM_DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n");
+
+  // The inner loop simple-reduction pattern requires storing the LCSSA PHI in
+  // the OuterLoop Latch. Therefore, when UndoSimpleReduction is enabled, skip
+  // that store during checks.
+  Instruction *Skip = nullptr;
+  if (EnableUndoSimpleReduction) {
+    if (InnerSimpleReductions.size() == 1)
+      Skip = InnerSimpleReductions[0]->LcssaStorer;
+  }
+
   // We do not have any basic block in between now make sure the outer header
   // and outer loop latch doesn't contain any unsafe instructions.
-  if (containsUnsafeInstructions(OuterLoopHeader) ||
-      containsUnsafeInstructions(OuterLoopLatch))
+  if (containsUnsafeInstructions(OuterLoopHeader, Skip) ||
+      containsUnsafeInstructions(OuterLoopLatch, Skip))
     return false;
 
   // Also make sure the inner loop preheader does not contain any unsafe
   // instructions. Note that all instructions in the preheader will be moved to
   // the outer loop header when interchanging.
   if (InnerLoopPreHeader != OuterLoopHeader &&
-      containsUnsafeInstructions(InnerLoopPreHeader))
+      containsUnsafeInstructions(InnerLoopPreHeader, Skip))
     return false;
 
   BasicBlock *InnerLoopExit = InnerLoop->getExitBlock();
@@ -787,7 +837,7 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) {
   // The inner loop exit block does flow to the outer loop latch and not some
   // other BBs, now make sure it contains safe instructions, since it will be
   // moved into the (new) inner loop after interchange.
-  if (containsUnsafeInstructions(InnerLoopExit))
+  if (containsUnsafeInstructions(InnerLoopExit, Skip))
     return false;
 
   LLVM_DEBUG(dbgs() << "Loops are perfectly nested\n");
@@ -898,6 +948,77 @@ static Value *followLCSSA(Value *SV) {
   return followLCSSA(PHI->getIncomingValue(0));
 }
 
+/// Return true if \p PHI is a reduction PHI of \p L that may be reordered by
+/// loop interchange. nsw/nuw-flagged instructions of an integer Add/Mul
+/// reduction chain are appended to \p HasNoWrapInsts (flags dropped later).
+static bool CheckReductionKind(Loop *L, PHINode *PHI,
+                               SmallVectorImpl<Instruction *> &HasNoWrapInsts) {
+  RecurrenceDescriptor RD;
+  if (!RecurrenceDescriptor::isReductionPHI(PHI, L, RD))
+    return false;
+
+  // Detect floating point reduction only when it can be reordered.
+  if (RD.getExactFPMathInst() != nullptr)
+    return false;
+
+  RecurKind RK = RD.getRecurrenceKind();
+  switch (RK) {
+  case RecurKind::Or:
+  case RecurKind::And:
+  case RecurKind::Xor:
+  case RecurKind::SMin:
+  case RecurKind::SMax:
+  case RecurKind::UMin:
+  case RecurKind::UMax:
+  case RecurKind::FAdd:
+  case RecurKind::FMul:
+  case RecurKind::FMin:
+  case RecurKind::FMax:
+  case RecurKind::FMinimum:
+  case RecurKind::FMaximum:
+  case RecurKind::FMinimumNum:
+  case RecurKind::FMaximumNum:
+  case RecurKind::FMulAdd:
+  case RecurKind::AnyOf:
+    return true;
+
+  // Changing the order of integer addition/multiplication may change the
+  // semantics. Consider the following case:
+  //
+  //  int A[2][2] = {{ INT_MAX, INT_MAX }, { INT_MIN, INT_MIN }};
+  //  int sum = 0;
+  //  for (int i = 0; i < 2; i++)
+  //    for (int j = 0; j < 2; j++)
+  //      sum += A[j][i];
+  //
+  // If the above loops are exchanged, the addition will overflow. To prevent
+  // this, we must drop the nuw/nsw flags from the addition/multiplication
+  // instructions when we actually exchange the loops.
+  case RecurKind::Add:
+  case RecurKind::Mul: {
+    unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
+    SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);
+
+    // Bail out when we fail to collect reduction instructions chain.
+    if (Ops.empty())
+      return false;
+
+    for (Instruction *I : Ops) {
+      assert(I->getOpcode() == OpCode &&
+             "Expected the instruction to be the reduction operation");
+      (void)OpCode;
+      // If the instruction has nuw/nsw flags, we must drop them when the
+      // transformation is actually performed.
+      if (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())
+        HasNoWrapInsts.push_back(I);
+    }
+    return true;
+  }
+  default:
+    return false;
+  }
+}
+
 // Check V's users to see if it is involved in a reduction in L.
 static PHINode *
 findInnerReductionPhi(Loop *L, Value *V,
@@ -910,72 +1031,12 @@ findInnerReductionPhi(Loop *L, Value *V,
     if (PHINode *PHI = dyn_cast<PHINode>(User)) {
       if (PHI->getNumIncomingValues() == 1)
         continue;
-      RecurrenceDescriptor RD;
-      if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD)) {
-        // Detect floating point reduction only when it can be reordered.
-        if (RD.getExactFPMathInst() != nullptr)
-          return nullptr;
-
-        RecurKind RK = RD.getRecurrenceKind();
-        switch (RK) {
-        case RecurKind::Or:
-        case RecurKind::And:
-        case RecurKind::Xor:
-        case RecurKind::SMin:
-        case RecurKind::SMax:
-        case RecurKind::UMin:
-        case RecurKind::UMax:
-        case RecurKind::FAdd:
-        case RecurKind::FMul:
-        case RecurKind::FMin:
-        case RecurKind::FMax:
-        case RecurKind::FMinimum:
-        case RecurKind::FMaximum:
-        case RecurKind::FMinimumNum:
-        case RecurKind::FMaximumNum:
-        case RecurKind::FMulAdd:
-        case RecurKind::AnyOf:
-          return PHI;
-
-        // Change the order of integer addition/multiplication may change the
-        // semantics. Consider the following case:
-        //
-        //  int A[2][2] = {{ INT_MAX, INT_MAX }, { INT_MIN, INT_MIN }};
-        //  int sum = 0;
-        //  for (int i = 0; i < 2; i++)
-        //    for (int j = 0; j < 2; j++)
-        //      sum += A[j][i];
-        //
-        // If the above loops are exchanged, the addition will cause an
-        // overflow. To prevent this, we must drop the nuw/nsw flags from the
-        // addition/multiplication instructions when we actually exchanges the
-        // loops.
-        case RecurKind::Add:
-        case RecurKind::Mul: {
-          unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
-          SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);
-
-          // Bail out when we fail to collect reduction instructions chain.
-          if (Ops.empty())
-            return nullptr;
-
-          for (Instruction *I : Ops) {
-            assert(I->getOpcode() == OpCode &&
-                   "Expected the instruction to be the reduction operation");
-            (void)OpCode;
-
-            // If the instruction has nuw/nsw flags, we must drop them when the
-            // transformation is actually performed.
-            if (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())
-              HasNoWrapInsts.push_back(I);
-          }
-          return PHI;
-        }
 
-        default:
-          return nullptr;
-        }
-      }
+      if (CheckReductionKind(L, PHI, HasNoWrapInsts))
+        return PHI;
+      else
+        return nullptr;
+
       return nullptr;
     }
   }
@@ -983,6 +1044,116 @@ findInnerReductionPhi(Loop *L, Value *V,
   return nullptr;
 }
 
+// Detect and record a simple reduction of the inner loop \p L:
+//
+//    innerloop:
+//        Re = phi<Init, Next>  ; Init must be a Constant
+//        ReUser = Re op ...    ; Re's only user, an associative binary op
+//        ...
+//        Next = ReUser op ...
+//    OuterLoopLatch:           ; the unique exit block of the inner loop
+//        Lcssa = phi<Next>     ; lcssa phi
+//        store Lcssa, MemRef   ; LcssaStorer, Lcssa's only user
+//
+// On success the pattern is recorded in InnerSimpleReductions and true is
+// returned; CheckReductionKind() may append nsw/nuw reduction instructions to
+// \p HasNoWrapInsts. undoSimpleReduction() later rewrites the pattern so the
+// partial result lives in MemRef instead of in a loop-carried PHI.
+bool LoopInterchangeLegality::findSimpleReduction(
+    Loop *L, PHINode *Phi, SmallVectorImpl<Instruction *> &HasNoWrapInsts) {
+  // Only undo simple reductions of the innermost loop of the pair.
+  if (!L->isInnermost())
+    return false;
+
+  if (Phi->getNumIncomingValues() != 2)
+    return false;
+
+  Value *Init = Phi->getIncomingValueForBlock(L->getLoopPreheader());
+  Value *Next = Phi->getIncomingValueForBlock(L->getLoopLatch());
+
+  // So far only constant initial values are supported.
+  if (!isa<Constant>(Init))
+    return false;
+
+  // Next must be a non-PHI instruction inside the inner loop, because
+  // undoSimpleReduction() moves the store to right after its definition.
+  auto *NextI = dyn_cast<Instruction>(Next);
+  if (!NextI || isa<PHINode>(NextI) || !L->contains(NextI))
+    return false;
+
+  // The reduction PHI must have exactly one user, inside the inner loop.
+  if (!Phi->hasOneUser())
+    return false;
+  auto *ReUser = dyn_cast<Instruction>(Phi->getUniqueUndroppableUser());
+  if (!ReUser || !L->contains(ReUser))
+    return false;
+
+  // The reduction operation must be an associative binary operator. This
+  // also rules out Sub/FSub, which are never associative.
+  if (!ReUser->isBinaryOp() || !ReUser->isAssociative())
+    return false;
+
+  // Longer reduction chains are validated by RecurrenceDescriptor.
+  if (ReUser != Next && !CheckReductionKind(L, Phi, HasNoWrapInsts))
+    return false;
+
+  // Locate the LCSSA PHI of Next in the single exit block of the inner loop,
+  // i.e. the latch of the outer loop.
+  BasicBlock *ExitingBlock = L->getExitingBlock();
+  BasicBlock *ExitBlock = L->getExitBlock();
+  if (!ExitingBlock || !ExitBlock)
+    return false;
+
+  // Apart from Phi itself, every user of Next must be a PHI; the first one in
+  // ExitBlock that receives Next from ExitingBlock is taken as Lcssa.
+  PHINode *Lcssa = nullptr;
+  for (User *U : Next->users()) {
+    auto *P = dyn_cast<PHINode>(U);
+    if (!P)
+      return false;
+    if (P == Phi)
+      continue;
+    if (!Lcssa && P->getParent() == ExitBlock &&
+        P->getIncomingValueForBlock(ExitingBlock) == Next)
+      Lcssa = P;
+  }
+  if (!Lcssa || !Lcssa->hasOneUser())
+    return false;
+
+  // Lcssa's only user must be a simple (non-volatile, non-atomic) store whose
+  // stored value, not its address, is Lcssa.
+  auto *LcssaStorer = dyn_cast<StoreInst>(Lcssa->getUniqueUndroppableUser());
+  if (!LcssaStorer || !LcssaStorer->isSimple() ||
+      LcssaStorer->getValueOperand() != Lcssa)
+    return false;
+
+  Value *MemRef = LcssaStorer->getPointerOperand();
+  Type *ElemTy = LcssaStorer->getValueOperand()->getType();
+
+  // undoSimpleReduction() loads from MemRef at the top of the inner loop
+  // header and moves the store into the inner loop, so MemRef must be
+  // available there. Non-instruction addresses (arguments, globals) always
+  // are; instruction addresses must dominate the inner loop header.
+  auto *MemRefI = dyn_cast<Instruction>(MemRef);
+  if (MemRefI && !DT->dominates(MemRefI, L->getHeader()))
+    return false;
+
+  // Found a simple reduction of the inner loop.
+  // FIXME: InnerSimpleReductions holds raw owning pointers that are never
+  // deleted; storing SimpleReduction by value would avoid this leak.
+  auto *SR = new SimpleReduction;
+  SR->Re = Phi;
+  SR->Init = Init;
+  SR->Next = Next;
+  SR->LcssaPhi = Lcssa;
+  SR->LcssaStorer = LcssaStorer;
+  SR->MemRef = MemRef;
+  SR->ElemTy = ElemTy;
+
+  InnerSimpleReductions.push_back(SR);
+  return true;
+}
+
 bool LoopInterchangeLegality::findInductionAndReductions(
     Loop *L, SmallVector<PHINode *, 8> &Inductions, Loop *InnerLoop) {
   if (!L->getLoopLatch() || !L->getLoopPredecessor())
@@ -995,11 +1166,14 @@ bool LoopInterchangeLegality::findInductionAndReductions(
       // PHIs in inner loops need to be part of a reduction in the outer loop,
       // discovered when checking the PHIs of the outer loop earlier.
       if (!InnerLoop) {
-        if (!OuterInnerReductions.count(&PHI)) {
-          LLVM_DEBUG(dbgs() << "Inner loop PHI is not part of reductions "
-                               "across the outer loop.\n");
+        if (OuterInnerReductions.count(&PHI)) {
+          LLVM_DEBUG(dbgs() << "Found a reduction across the outer loop.\n");
+        } else if (EnableUndoSimpleReduction &&
+                   findSimpleReduction(L, &PHI, HasNoWrapReductions)) {
+          LLVM_DEBUG(dbgs() << "Found a simple reduction in the inner loop: \n"
+                            << PHI << '\n');
+        } else
           return false;
-        }
       } else {
         assert(PHI.getNumIncomingValues() == 2 &&
                "Phis in loop header should have exactly 2 incoming values");
@@ -1020,6 +1194,10 @@ bool LoopInterchangeLegality::findInductionAndReductions(
       }
     }
   }
+
+  // For now we only support at most one reduction.
+  if (InnerSimpleReductions.size() > 1)
+    return false;
   return true;
 }
 
@@ -1115,12 +1293,15 @@ bool LoopInterchangeLegality::findInductions(
 // the we are only interested in the final value after the loop).
 static bool
 areInnerLoopExitPHIsSupported(Loop *InnerL, Loop *OuterL,
-                              SmallPtrSetImpl<PHINode *> &Reductions) {
+                              SmallPtrSetImpl<PHINode *> &Reductions,
+                              PHINode *LcssaSimpleRed) {
   BasicBlock *InnerExit = OuterL->getUniqueExitBlock();
   for (PHINode &PHI : InnerExit->phis()) {
     // Reduction lcssa phi will have only 1 incoming block that from loop latch.
     if (PHI.getNumIncomingValues() > 1)
       return false;
+    if (&PHI == LcssaSimpleRed)
+      continue; // Already validated as a simple reduction; check the rest.
     if (any_of(PHI.users(), [&Reductions, OuterL](User *U) {
           PHINode *PN = dyn_cast<PHINode>(U);
           return !PN ||
@@ -1270,8 +1451,16 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId,
     return false;
   }
 
-  if (!areInnerLoopExitPHIsSupported(OuterLoop, InnerLoop,
-                                     OuterInnerReductions)) {
+  // The LCSSA PHI for the simple reduction has passed checks before; its user
+  // is a store instruction.
+  PHINode *LcssaSimpleRed = nullptr;
+  if (EnableUndoSimpleReduction) {
+    if (InnerSimpleReductions.size() == 1)
+      LcssaSimpleRed = InnerSimpleReductions[0]->LcssaPhi;
+  }
+
+  if (!areInnerLoopExitPHIsSupported(OuterLoop, InnerLoop, OuterInnerReductions,
+                                     LcssaSimpleRed)) {
     LLVM_DEBUG(dbgs() << "Found unsupported PHI nodes in inner loop exit.\n");
     ORE->emit([&]() {
       return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedExitPHI",
@@ -1633,10 +1822,66 @@ void LoopInterchangeTransform::restructureLoops(
   SE->forgetLoop(NewOuter);
 }
 
+/*
+  Users can write, and optimizers can generate, simple reductions in the inner
+  loop. To make interchange valid, we undo the reduction by moving the
+  initialization and the store into the inner loop. So far we only handle
+  reductions whose variable is initialized to a constant. For example:
+
+  loop:
+    re = phi<0.0, next>
+    next = re op ...
+  reduc_sum = phi<next>       // lcssa phi
+  MEM_REF[idx] = reduc_sum    // LcssaStorer
+
+  is transformed into:
+
+  loop:
+    tmp = MEM_REF[idx];
+    new_var = !first_iteration ? tmp : 0.0;
+    next = new_var op ...
+    MEM_REF[idx] = next;      // after moving
+
+  In this way the initial constant is used only in the first iteration.
+*/
+void LoopInterchangeTransform::undoSimpleReduction() {
+  auto &InnerSimpleReductions = LIL.getInnerSimpleReductions();
+  auto &InductionPHIs = LIL.getInnerLoopInductions();
+  // Nothing to anchor the first-iteration test on; leave the IR unchanged.
+  if (InnerSimpleReductions.empty() || InductionPHIs.empty())
+    return;
+  LoopInterchangeLegality::SimpleReduction *SR = InnerSimpleReductions[0];
+  BasicBlock *InnerLoopHeader = InnerLoop->getHeader();
+  IRBuilder<> Builder(&*(InnerLoopHeader->getFirstNonPHIIt()));
+
+  // The reduction starts from a constant, so reload the running value from
+  // the memory object, honoring the alignment of the original store.
+  Instruction *LoadMem = Builder.CreateAlignedLoad(
+      SR->ElemTy, SR->MemRef, SR->LcssaStorer->getAlign());
+
+  // NotFirst is true on every inner-loop iteration except the first one.
+  PHINode *IV = InductionPHIs[0];
+  Value *IVInit = IV->getIncomingValueForBlock(InnerLoop->getLoopPreheader());
+  Value *NotFirst = Builder.CreateICmpNE(IV, IVInit, "not.first.iter");
+
+  // new_var = MEM_REF on later iterations, the initial constant on the first.
+  Value *NewVar = Builder.CreateSelect(NotFirst, LoadMem, SR->Init, "new.var");
+
+  // Replace the reduction PHI with the new value, then move the store right
+  // after Next's definition so the updated value is stored every iteration.
+  SR->Re->replaceAllUsesWith(NewVar);
+  SR->LcssaStorer->setOperand(0, SR->Next);
+  SR->LcssaStorer->moveAfter(cast<Instruction>(SR->Next));
+}
+
 bool LoopInterchangeTransform::transform(
     ArrayRef<Instruction *> DropNoWrapInsts) {
   bool Transformed = false;
 
+  auto &InnerSimpleReductions = LIL.getInnerSimpleReductions();
+  if (EnableUndoSimpleReduction && InnerSimpleReductions.size() == 1)
+    undoSimpleReduction();
+
   if (InnerLoop->getSubLoops().empty()) {
     BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
     LLVM_DEBUG(dbgs() << "Splitting the inner loop latch\n");
diff --git a/llvm/test/Transforms/LoopInterchange/simple-reduction.ll b/llvm/test/Transforms/LoopInterchange/simple-reduction.ll
new file mode 100644
index 0000000000000..9a4393f827a36
--- /dev/null
+++ b/llvm/test/Transforms/LoopInterchange/simple-reduction.ll
@@ -0,0 +1,86 @@
+; Test that loop-interchange handles a simple reduction in the inner loop by undoing it (-undo-simple-reduction).
+; RUN: opt < %s -passes="loop(loop-interchange),dce"  -undo-simple-reduction -loop-interchange-profitabilities=ignore -S | FileCheck %s
+
+; for (int i = 0; i < n; i++) {
+;   s[i] = 0;
+;   for (int j = 0; j < n; j++)
+;     s[i] = s[i] + a[j][i] * b[j][i];
+; }
+
+define void @func(ptr noalias noundef readonly captures(none) %a, ptr noalias noundef readonly captures(none) %b, ptr noalias noundef writeonly captures(none) %s, i64 noundef %n) {
+; CHECK-LABEL: define void @func(ptr noalias noundef readonly captures(none) %a, ptr noalias noundef readonly captures(none) %b, ptr noalias noundef writeonly captures(none) %s, i64 noundef %n) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i64 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[CMP]], label [[INNERLOOP_PREHEADER:%.*]], label [[EXIT:%.*]]
+; CHECK:       outerloop_header.preheader:
+; CHECK-NEXT:    br label [[OUTERLOOP_HEADER:%.*]]
+; CHECK:       outerloop_header:
+; CHECK-NEXT:    [[INDEX_I:%.*]] = phi i64 [ [[I_NEXT:%.*]], [[OUTERLOOP_LATCH:%.*]] ], [ 0, [[OUTERLOOPHEADER_PREHEADER:%.*]] ]
+; CHECK-NEXT:    [[ADDR_S:%.*]] = getelementptr inbounds nuw double, ptr %s, i64 [[INDEX_I]]
+; CHECK-NEXT:    [[ADDR_A:%.*]] = getelementptr inbounds nuw [100 x double], ptr %a, i64 0, i64 [[INDEX_I]]
+; CHECK-NEXT:    [[ADDR_B:%.*]] = getelementptr inbounds nuw [100 x double], ptr %b, i64 0, i64 [[INDEX_I]]
+; CHECK-NEXT:    br label [[INNERLOOP_SPLIT1:%.*]]
+; CHECK:       innerloop.preheader:
+; CHECK-NEXT:    br label [[INNERLOOP:%.*]]
+; CHECK:       innerloop:
+; CHECK-NEXT:    [[INDEX_J:%.*]] = phi i64 [ [[J_NEXT:%.*]], [[INNERLOOP_SPLIT:%.*]] ], [ 0, [[INNERLOOP_PREHEADER:%.*]] ]
+; CHECK-NEXT:    br label [[OUTERLOOPHEADER_PREHEADER:%.*]]
+; CHECK:       innerloop.split1:
+; CHECK-NEXT:    [[S:%.*]] = load double, ptr [[ADDR_S]], align 8
+; CHECK-NEXT:    [[FIRSTITER:%.*]] = icmp ne i64 [[INDEX_J]], 0
+; CHECK-NEXT:    [[NEW_VAR:%.*]] = select i1 [[FIRSTITER]], double [[S]], double 0.000000e+00
+; CHECK-NEXT:    [[ADDR_A_J_I:%.*]] = getelementptr inbounds nuw [100 x double], ptr [[ADDR_A]], i64 [[INDEX_J]]
+; CHECK-NEXT:    [[A_J_I:%.*]] = load double, ptr [[ADDR_A_J_I]], align 8
+; CHECK-NEXT:    [[ADDR_B_J_I:%.*]] = getelementptr inbounds nuw [100 x double], ptr [[ADDR_B]], i64 [[INDEX_J]]
+; CHECK-NEXT:    [[B_J_I:%.*]] = load double, ptr [[ADDR_B_J_I]], align 8
+; CHECK-NEXT:    [[MUL:%.*]] = fmul fast double [[B_J_I]], [[A_J_I]]
+; CHECK-NEXT:    [[ADD:%.*]] = fadd fast double [[MUL]], [[NEW_VAR]]
+; CHECK-NEXT:    store double [[ADD]], ptr [[ADDR_S]], align 8
+; CHECK-NEXT:    br label [[OUTERLOOP_LATCH:%.*]]
+; CHECK:       innerloop.split:
+; CHECK-NEXT:    [[J_NEXT:%.*]] = add nuw nsw i64 [[INDEX_J]], 1
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i64 [[J_NEXT]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[EXIT_LOOPEXIT:%.*]], label [[INNERLOOP]]
+; CHECK:       outerloop_latch:
+; CHECK-NEXT:    [[I_NEXT]] = add nuw nsw i64 [[INDEX_I]], 1
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i64 [[I_NEXT]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP2]], label [[INNERLOOP_SPLIT:%.*]], label [[OUTERLOOP_HEADER]]
+; CHECK:       exit.loopexit:
+; CHECK-NEXT:    br label [[EXIT:%.*]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp = icmp sgt i64 %n, 0
+  br i1 %cmp, label %outerloop_header, label %exit
+
+outerloop_header:
+  %index_i = phi i64 [ 0, %entry ], [ %index_i.next, %outerloop_latch ]
+  %addr_s = getelementptr inbounds nuw double, ptr %s, i64 %index_i
+  %invariant.gep.us = getelementptr inbounds nuw [100 x double], ptr %a, i64 0, i64 %index_i
+  %invariant.gep32.us = getelementptr inbounds nuw [100 x double], ptr %b, i64 0, i64 %index_i
+  br label %innerloop
+
+innerloop:
+  %index_j = phi i64 [ 0, %outerloop_header ], [ %index_j.next, %innerloop ]
+  %reduction = phi double [ 0.000000e+00, %outerloop_header ], [ %add, %innerloop ]
+  %addr_a_j_i = getelementptr inbounds nuw [100 x double], ptr %invariant.gep.us, i64 %index_j
+  %0 = load double, ptr %addr_a_j_i, align 8
+  %addr_b_j_i = getelementptr inbounds nuw [100 x double], ptr %invariant.gep32.us, i64 %index_j
+  %1 = load double, ptr %addr_b_j_i, align 8
+  %mul = fmul fast double %1, %0
+  %add = fadd fast double %mul, %reduction
+  %index_j.next = add nuw nsw i64 %index_j, 1
+  %cond1 = icmp eq i64 %index_j.next, %n
+  br i1 %cond1, label %outerloop_latch, label %innerloop
+
+outerloop_latch:
+  %lcssa = phi double [ %add, %innerloop ]
+  store double %lcssa, ptr %addr_s, align 8
+  %index_i.next = add nuw nsw i64 %index_i, 1
+  %cond2 = icmp eq i64 %index_i.next, %n
+  br i1 %cond2, label %exit, label %outerloop_header
+
+exit:
+  ret void
+}



More information about the llvm-commits mailing list