[llvm] [LoopInterchange] Support inner-loop simple reductions via UndoSimpleReduction (PR #172970)

Ryotaro Kasuga via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 24 06:14:12 PST 2025


================
@@ -910,79 +1035,183 @@ findInnerReductionPhi(Loop *L, Value *V,
     if (PHINode *PHI = dyn_cast<PHINode>(User)) {
       if (PHI->getNumIncomingValues() == 1)
         continue;
-      RecurrenceDescriptor RD;
-      if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD)) {
-        // Detect floating point reduction only when it can be reordered.
-        if (RD.getExactFPMathInst() != nullptr)
-          return nullptr;
-
-        RecurKind RK = RD.getRecurrenceKind();
-        switch (RK) {
-        case RecurKind::Or:
-        case RecurKind::And:
-        case RecurKind::Xor:
-        case RecurKind::SMin:
-        case RecurKind::SMax:
-        case RecurKind::UMin:
-        case RecurKind::UMax:
-        case RecurKind::FAdd:
-        case RecurKind::FMul:
-        case RecurKind::FMin:
-        case RecurKind::FMax:
-        case RecurKind::FMinimum:
-        case RecurKind::FMaximum:
-        case RecurKind::FMinimumNum:
-        case RecurKind::FMaximumNum:
-        case RecurKind::FMulAdd:
-        case RecurKind::AnyOf:
-          return PHI;
-
-        // Change the order of integer addition/multiplication may change the
-        // semantics. Consider the following case:
-        //
-        //  int A[2][2] = {{ INT_MAX, INT_MAX }, { INT_MIN, INT_MIN }};
-        //  int sum = 0;
-        //  for (int i = 0; i < 2; i++)
-        //    for (int j = 0; j < 2; j++)
-        //      sum += A[j][i];
-        //
-        // If the above loops are exchanged, the addition will cause an
-        // overflow. To prevent this, we must drop the nuw/nsw flags from the
-        // addition/multiplication instructions when we actually exchanges the
-        // loops.
-        case RecurKind::Add:
-        case RecurKind::Mul: {
-          unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
-          SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);
-
-          // Bail out when we fail to collect reduction instructions chain.
-          if (Ops.empty())
-            return nullptr;
-
-          for (Instruction *I : Ops) {
-            assert(I->getOpcode() == OpCode &&
-                   "Expected the instruction to be the reduction operation");
-            (void)OpCode;
-
-            // If the instruction has nuw/nsw flags, we must drop them when the
-            // transformation is actually performed.
-            if (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())
-              HasNoWrapInsts.push_back(I);
-          }
-          return PHI;
-        }
 
-        default:
-          return nullptr;
-        }
-      }
-      return nullptr;
+      if (CheckReductionKind(L, PHI, HasNoWrapInsts))
+        return PHI;
+      else
+        return nullptr;
     }
   }
 
   return nullptr;
 }
 
+static PHINode *getCounterFromInc(Value *IncV, Loop *L) {
+  Instruction *IncI = dyn_cast<Instruction>(IncV);
+  if (!IncI)
+    return nullptr;
+  // Only an add or a sub is accepted as the increment; anything else fails.
+  if (IncI->getOpcode() != Instruction::Add &&
+      IncI->getOpcode() != Instruction::Sub)
+    return nullptr;
+  // The counter is an operand that is a PHI located in the header of L.
+  PHINode *Phi = dyn_cast<PHINode>(IncI->getOperand(0));
+  if (Phi && Phi->getParent() == L->getHeader()) {
+    return Phi;
+  }
+  // Operand 1 is tried as well so that the commuted form matches.
+  // NOTE(review): sub is not commutative (`C - phi`) -- confirm this is wanted.
+  Phi = dyn_cast<PHINode>(IncI->getOperand(1));
+  if (Phi && Phi->getParent() == L->getHeader()) {
+    return Phi;
+  }
+  // Neither operand is a PHI in the header of L.
+  return nullptr;
+}
+
+/// UndoSimpleReduction needs a first-iteration check, so find the IV that the
+/// latch exit condition of \p L tests. Returns nullptr if none is recognized.
+static PHINode *findCounterIV(Loop *L) {
+  // The latch terminator is inspected below, so a latch block must exist.
+  assert(L->getLoopLatch() && "Must be in simplified form");
+  // NOTE(review): cast<> presumes the latch ends in a BranchInst -- confirm.
+  BranchInst *BI = cast<BranchInst>(L->getLoopLatch()->getTerminator());
+  if (L->isLoopInvariant(BI->getCondition()))
+    return nullptr;
+  // Only a condition that is an ICmpInst is handled.
+  ICmpInst *Cond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!Cond)
+    return nullptr;
+  // Arrange the operands so that RHS is the loop-invariant one: swap them
+  // when only LHS is invariant, and give up when neither operand is.
+  Value *LHS = Cond->getOperand(0);
+  Value *RHS = Cond->getOperand(1);
+  if (!L->isLoopInvariant(RHS)) {
+    if (!L->isLoopInvariant(LHS))
+      return nullptr;
+    std::swap(LHS, RHS);
+  }
+
+  // Expected shape: IndVar = phi [InitialValue, preheader], [StepInst, latch]
+  //                 StepInst = IndVar + step
+  // Case 1: the compare tests the PHI itself (cmp = IndVar < FinalValue).
+  // NOTE(review): this PHI is not checked to be in L's header -- confirm.
+  PHINode *Counter = dyn_cast<PHINode>(LHS);
+  // Case 2: the compare tests the stepped value (cmp = StepInst < FinalValue),
+  // so the PHI is recovered from the add/sub that feeds the compare.
+  if (!Counter)
+    Counter = getCounterFromInc(LHS, L);
+
+  return Counter;
+}
+
+/// Detect and record the simple reduction of the inner loop.
+///
+///    innerloop:
+///        Re = phi<0.0, Next>
+///        ReUser = Re op ...
+///        ...
+///        Next = ReUser op ...
+///    OuterLoopLatch:
+///        Lcssa = phi<Next>    ; lcssa phi
+///        store Lcssa, MemRef  ; LcssaStorer
+///
+bool LoopInterchangeLegality::findSimpleReduction(
+    Loop *L, PHINode *Phi, SmallVectorImpl<Instruction *> &HasNoWrapInsts) {
+
+  // Only support undoing a simple reduction if the loop nest to be
+  // interchanged is the innermost two loops.
+  if (!L->isInnermost())
+    return false;
+
+  if (Phi->getNumIncomingValues() != 2)
+    return false;
+
+  Value *Init = Phi->getIncomingValueForBlock(L->getLoopPreheader());
+  Value *Next = Phi->getIncomingValueForBlock(L->getLoopLatch());
+
+  // So far only supports constant initial value.
+  if (!isa<Constant>(Init))
+    return false;
+
+  // The reduction result must live in the inner loop.
+  if (Instruction *I = dyn_cast<Instruction>(Next)) {
+    BasicBlock *BB = I->getParent();
+    if (!L->contains(BB))
+      return false;
+  }
+
+  // The reduction should have only one user.
+  if (!Phi->hasOneUser())
+    return false;
+  Instruction *ReUser = dyn_cast<Instruction>(Phi->getUniqueUndroppableUser());
+  if (!ReUser || !L->contains(ReUser->getParent()))
+    return false;
+
+  // Check the reduction operation.
+  if (!ReUser->isAssociative())
+    return false;
+
+  // Check the reduction kind.
+  if (ReUser != Next && !CheckReductionKind(L, Phi, HasNoWrapInsts))
+    return false;
----------------
kasuga-fj wrote:

Do we need to call `CheckReductionKind` here?

https://github.com/llvm/llvm-project/pull/172970


More information about the llvm-commits mailing list