[llvm] [InstCombine] Combine interleaved PHI reduction chains. (PR #143878)

Ricardo Jesus via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 16 03:22:14 PDT 2025


================
@@ -996,6 +997,152 @@ Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) {
   return NewCI;
 }
 
+/// Try to fold reduction ops interleaved through two PHIs to a single PHI.
+///
+/// For example, combine:
+///   %phi1 = phi [init1, %BB1], [%op1, %BB2]
+///   %phi2 = phi [init2, %BB1], [%op2, %BB2]
+///   %op1 = binop %phi1, constant1
+///   %op2 = binop %phi2, constant2
+///   %rdx = binop %op1, %op2
+/// =>
+///   %phi_combined = phi [init_combined, %BB1], [%rdx_combined, %BB2]
+///   %rdx_combined = binop %phi_combined, constant_combined
+///
+/// For now, we require init1, init2, constant1 and constant2 to be constants.
+Instruction *InstCombinerImpl::foldPHIReduction(PHINode &PN) {
+  BinaryOperator *BO1;
+  Value *Start1;
+  Value *Step1;
+
+  // Find the first recurrence.
+  if (!PN.hasOneUse() || !matchSimpleRecurrence(&PN, BO1, Start1, Step1))
+    return nullptr;
+
+  // Ensure BO1 has two uses (PN and the reduction op) and can be reassociated.
+  if (!BO1->hasNUses(2) || !BO1->isAssociative())
+    return nullptr;
+
+  // Convert Start1 and Step1 to constants.
+  auto *Init1 = dyn_cast<Constant>(Start1);
+  auto *C1 = dyn_cast<Constant>(Step1);
+  if (!Init1 || !C1)
+    return nullptr;
+
+  // Find the reduction operation.
+  auto Opc = BO1->getOpcode();
+  BinaryOperator *Rdx = nullptr;
+  for (User *U : BO1->users())
+    if (U != &PN) {
+      Rdx = dyn_cast<BinaryOperator>(U);
+      break;
+    }
+  if (!Rdx || Rdx->getOpcode() != Opc || !Rdx->isAssociative())
+    return nullptr;
+
+  // Find the interleaved binop.
+  assert((Rdx->getOperand(0) == BO1 || Rdx->getOperand(1) == BO1) &&
+         "Unexpected operand!");
+  auto *BO2 =
+      dyn_cast<BinaryOperator>(Rdx->getOperand(Rdx->getOperand(0) == BO1));
+  if (!BO2 || !BO2->hasNUses(2) || !BO2->isAssociative() ||
+      BO2->getOpcode() != Opc || BO2->getParent() != BO1->getParent())
+    return nullptr;
+
+  // Find the interleaved PHI and recurrence constants.
+  PHINode *PN2;
+  Value *Start2;
+  Value *Step2;
+  if (!matchSimpleRecurrence(BO2, PN2, Start2, Step2) || !PN2->hasOneUse() ||
+      PN2->getParent() != PN.getParent())
+    return nullptr;
+
+  assert(PN2->getNumIncomingValues() == PN.getNumIncomingValues() &&
+         "Expected PHIs with the same number of incoming values!");
+
+  // Convert Start2 and Step2 to constants.
+  auto *Init2 = dyn_cast<Constant>(Start2);
+  auto *C2 = dyn_cast<Constant>(Step2);
+  if (!Init2 || !C2)
+    return nullptr;
+
+  assert(BO1->isCommutative() && BO2->isCommutative() && Rdx->isCommutative() &&
+         "Expected commutative instructions!");
+
+  // If we've got this far, we can transform:
+  //   pn = phi [init1; op1]
+  //   pn2 = phi [init2; op2]
+  //   op1 = binop (pn, c1)
+  //   op2 = binop (pn2, c2)
+  //   rdx = binop (op1, op2)
+  // Into:
+  //   pn = phi [binop (init1, init2); rdx]
+  //   rdx = binop (pn, binop (c1, c2))
+
+  // Attempt to fold the constants.
+  auto *Init = llvm::ConstantFoldBinaryInstruction(Opc, Init1, Init2);
+  auto *C = llvm::ConstantFoldBinaryInstruction(Opc, C1, C2);
+  if (!Init || !C)
+    return nullptr;
+
+  LLVM_DEBUG(dbgs() << "  Combining " << PN << "\n            " << *BO1
+                    << "\n       with " << *PN2 << "\n            " << *BO2
+                    << '\n');
+  ++NumPHIsInterleaved;
+
+  // Create the new PHI.
+  auto *NewPN = PHINode::Create(PN.getType(), PN.getNumIncomingValues());
+
+  // Create the new binary op.
+  auto *NewOp = BinaryOperator::Create(Opc, NewPN, C);
+  if (Opc == Instruction::FAdd || Opc == Instruction::FMul) {
+    // Intersect FMF flags for FADD and FMUL.
+    FastMathFlags Intersect = BO1->getFastMathFlags() &
+                              BO2->getFastMathFlags() & Rdx->getFastMathFlags();
+    NewOp->setFastMathFlags(Intersect);
+  } else {
+    OverflowTracking Flags;
+    Flags.AllKnownNonNegative = false;
+    Flags.AllKnownNonZero = false;
+    Flags.mergeFlags(*BO1);
+    Flags.mergeFlags(*BO2);
+    Flags.mergeFlags(*Rdx);
+    Flags.applyFlags(*NewOp);
+  }
+  InsertNewInstWith(NewOp, BO1->getIterator());
+  replaceInstUsesWith(*Rdx, NewOp);
+
+  for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) {
+    auto *V = PN.getIncomingValue(I);
+    auto *BB = PN.getIncomingBlock(I);
+    if (V == Init1) {
+      assert(((PN2->getIncomingValue(0) == Init2 &&
+               PN2->getIncomingBlock(0) == BB) ||
+              (PN2->getIncomingValue(1) == Init2 &&
+               PN2->getIncomingBlock(1) == BB)) &&
+             "Invalid incoming block!");
+      NewPN->addIncoming(Init, BB);
+    } else if (V == BO1) {
+      assert(((PN2->getIncomingValue(0) == BO2 &&
+               PN2->getIncomingBlock(0) == BB) ||
+              (PN2->getIncomingValue(1) == BO2 &&
+               PN2->getIncomingBlock(1) == BB)) &&
+             "Invalid incoming block!");
+      NewPN->addIncoming(NewOp, BB);
+    } else
+      llvm_unreachable("Unexpected incoming value!");
+  }
+
+  // Remove dead instructions. BO1/2 are replaced with poison to clean up their
+  // uses.
+  eraseInstFromFunction(*Rdx);
+  eraseInstFromFunction(*replaceInstUsesWith(*BO1, BO1));
+  eraseInstFromFunction(*replaceInstUsesWith(*BO2, BO2));
----------------
rj-jesus wrote:

Thanks, done.

https://github.com/llvm/llvm-project/pull/143878


More information about the llvm-commits mailing list