[llvm] [InstCombine] Combine interleaved PHI reduction chains. (PR #143878)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 16 02:13:22 PDT 2025
================
@@ -996,6 +997,152 @@ Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) {
return NewCI;
}
+/// Try to fold reduction ops interleaved through two PHIs to a single PHI.
+///
+/// For example, combine:
+/// %phi1 = phi [init1, %BB1], [%op1, %BB2]
+/// %phi2 = phi [init2, %BB1], [%op2, %BB2]
+/// %op1 = binop %phi1, constant1
+/// %op2 = binop %phi2, constant2
+/// %rdx = binop %op1, %op2
+/// =>
+/// %phi_combined = phi [init_combined, %BB1], [%op_combined, %BB2]
+/// %rdx_combined = binop %phi_combined, constant_combined
+///
+/// For now, we require init1, init2, constant1 and constant2 to be constants.
+Instruction *InstCombinerImpl::foldPHIReduction(PHINode &PN) {
+ BinaryOperator *BO1;
+ Value *Start1;
+ Value *Step1;
+
+ // Find the first recurrence.
+ if (!PN.hasOneUse() || !matchSimpleRecurrence(&PN, BO1, Start1, Step1))
+ return nullptr;
+
+ // Ensure BO1 has two uses (PN and the reduction op) and can be reassociated.
+ if (!BO1->hasNUses(2) || !BO1->isAssociative())
+ return nullptr;
+
+ // Convert Start1 and Step1 to constants.
+ auto *Init1 = dyn_cast<Constant>(Start1);
+ auto *C1 = dyn_cast<Constant>(Step1);
+ if (!Init1 || !C1)
+ return nullptr;
+
+ // Find the reduction operation.
+ auto Opc = BO1->getOpcode();
+ BinaryOperator *Rdx = nullptr;
+ for (User *U : BO1->users())
+ if (U != &PN) {
+ Rdx = dyn_cast<BinaryOperator>(U);
+ break;
+ }
+ if (!Rdx || Rdx->getOpcode() != Opc || !Rdx->isAssociative())
+ return nullptr;
+
+ // Find the interleaved binop.
+ assert((Rdx->getOperand(0) == BO1 || Rdx->getOperand(1) == BO1) &&
+ "Unexpected operand!");
+ auto *BO2 =
+ dyn_cast<BinaryOperator>(Rdx->getOperand(Rdx->getOperand(0) == BO1));
+ if (!BO2 || !BO2->hasNUses(2) || !BO2->isAssociative() ||
+ BO2->getOpcode() != Opc || BO2->getParent() != BO1->getParent())
+ return nullptr;
+
+ // Find the interleaved PHI and recurrence constants.
+ PHINode *PN2;
+ Value *Start2;
+ Value *Step2;
+ if (!matchSimpleRecurrence(BO2, PN2, Start2, Step2) || !PN2->hasOneUse() ||
+ PN2->getParent() != PN.getParent())
+ return nullptr;
+
+ assert(PN2->getNumIncomingValues() == PN.getNumIncomingValues() &&
+ "Expected PHIs with the same number of incoming values!");
+
+ // Convert Start2 and Step2 to constants.
+ auto *Init2 = dyn_cast<Constant>(Start2);
+ auto *C2 = dyn_cast<Constant>(Step2);
+ if (!Init2 || !C2)
+ return nullptr;
+
+ assert(BO1->isCommutative() && BO2->isCommutative() && Rdx->isCommutative() &&
+ "Expected commutative instructions!");
+
+ // If we've got this far, we can transform:
+ // pn = phi [init1; op1]
+ // pn2 = phi [init2; op2]
+ // op1 = binop (pn, c1)
+ // op2 = binop (pn2, c2)
+ // rdx = binop (op1, op2)
+ // Into:
+ // pn = phi [binop (init1, init2); rdx]
+ // rdx = binop (pn, binop (c1, c2))
+
+ // Attempt to fold the constants.
+ auto *Init = llvm::ConstantFoldBinaryInstruction(Opc, Init1, Init2);
+ auto *C = llvm::ConstantFoldBinaryInstruction(Opc, C1, C2);
+ if (!Init || !C)
+ return nullptr;
+
+ LLVM_DEBUG(dbgs() << " Combining " << PN << "\n " << *BO1
+ << "\n with " << *PN2 << "\n " << *BO2
+ << '\n');
+ ++NumPHIsInterleaved;
+
+ // Create the new PHI.
+ auto *NewPN = PHINode::Create(PN.getType(), PN.getNumIncomingValues());
+
+ // Create the new binary op.
+ auto *NewOp = BinaryOperator::Create(Opc, NewPN, C);
+ if (Opc == Instruction::FAdd || Opc == Instruction::FMul) {
+ // Intersect FMF flags for FADD and FMUL.
+ FastMathFlags Intersect = BO1->getFastMathFlags() &
+ BO2->getFastMathFlags() & Rdx->getFastMathFlags();
+ NewOp->setFastMathFlags(Intersect);
+ } else {
+ OverflowTracking Flags;
+ Flags.AllKnownNonNegative = false;
+ Flags.AllKnownNonZero = false;
+ Flags.mergeFlags(*BO1);
+ Flags.mergeFlags(*BO2);
+ Flags.mergeFlags(*Rdx);
+ Flags.applyFlags(*NewOp);
+ }
+ InsertNewInstWith(NewOp, BO1->getIterator());
+ replaceInstUsesWith(*Rdx, NewOp);
+
+ for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) {
+ auto *V = PN.getIncomingValue(I);
+ auto *BB = PN.getIncomingBlock(I);
+ if (V == Init1) {
+ assert(((PN2->getIncomingValue(0) == Init2 &&
+ PN2->getIncomingBlock(0) == BB) ||
+ (PN2->getIncomingValue(1) == Init2 &&
+ PN2->getIncomingBlock(1) == BB)) &&
+ "Invalid incoming block!");
+ NewPN->addIncoming(Init, BB);
+ } else if (V == BO1) {
+ assert(((PN2->getIncomingValue(0) == BO2 &&
+ PN2->getIncomingBlock(0) == BB) ||
+ (PN2->getIncomingValue(1) == BO2 &&
+ PN2->getIncomingBlock(1) == BB)) &&
+ "Invalid incoming block!");
+ NewPN->addIncoming(NewOp, BB);
+ } else
+ llvm_unreachable("Unexpected incoming value!");
+ }
+
+ // Remove dead instructions. BO1/2 are replaced with poison to clean up their
+ // uses.
+ eraseInstFromFunction(*Rdx);
+ eraseInstFromFunction(*replaceInstUsesWith(*BO1, BO1));
+ eraseInstFromFunction(*replaceInstUsesWith(*BO2, BO2));
----------------
dtcxzyw wrote:
```suggestion
eraseInstFromFunction(*replaceInstUsesWith(*BO1, PoisonValue::get(BO1->getType())));
eraseInstFromFunction(*replaceInstUsesWith(*BO2, PoisonValue::get(BO2->getType())));
```
https://github.com/llvm/llvm-project/pull/143878
More information about the llvm-commits mailing list