[llvm] [VectorCombine] Fold binary op of reductions. (PR #121567)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 12 04:39:50 PST 2025
================
@@ -1182,6 +1183,135 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
return true;
}
+/// Estimate the costs associated with the vector reduction intrinsic \p II.
+///
+/// Two values are written through the out-parameters:
+///  - \p CostBeforeReduction: cost of the extend / multiply instructions that
+///    feed the reduction when one of the patterns below is recognised,
+///    otherwise 0.
+///  - \p CostAfterReduction: cost of the reduction itself, queried from
+///    \p TTI with the hook that corresponds to the recognised pattern.
+///
+/// Recognised patterns, tried from most to least specific:
+///  1. reduce.add(ext(mul(ext(A), ext(B))))  -> getMulAccReductionCost
+///  2. reduce(ext(A))                        -> getExtendedReductionCost
+///  3. anything else                         -> getArithmeticReductionCost
+///
+/// \param II        The reduction intrinsic call; operand 0 must be a vector.
+/// \param CostKind  The cost kind forwarded to every TTI query.
+/// \param TTI       Target cost model used for all queries.
+static void analyzeCostOfVecReduction(const IntrinsicInst &II,
+                                      TTI::TargetCostKind CostKind,
+                                      const TargetTransformInfo &TTI,
+                                      InstructionCost &CostBeforeReduction,
+                                      InstructionCost &CostAfterReduction) {
+  using namespace llvm::PatternMatch;
+  Instruction *Op0 = nullptr, *Op1 = nullptr;
+  auto *RedOp = dyn_cast<Instruction>(II.getOperand(0));
+  auto *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
+  unsigned ReductionOpc =
+      getArithmeticReductionInstruction(II.getIntrinsicID());
+
+  // The multiply-accumulate pattern has to be tested before the plain
+  // extended-reduction pattern: ext(mul(ext, ext)) is itself an ext(...), so
+  // testing the generic form first would make this branch unreachable.
+  if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
+      match(RedOp,
+            m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) &&
+      match(Op0, m_ZExtOrSExt(m_Value())) &&
+      Op0->getOpcode() == Op1->getOpcode() &&
+      Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
+      (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
+    // Matched reduce.add(ext(mul(ext(A), ext(B)))
+    bool IsUnsigned = isa<ZExtInst>(Op0);
+    auto *ExtType = cast<VectorType>(Op0->getOperand(0)->getType());
+    // Op0 is an extend feeding a vector multiply, so its type already is the
+    // vector type of the multiply; do not wrap it in another vector type.
+    auto *MulType = cast<VectorType>(Op0->getType());
+
+    InstructionCost ExtCost =
+        TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
+                             TTI::CastContextHint::None, CostKind, Op0);
+    InstructionCost MulCost =
+        TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
+    InstructionCost Ext2Cost =
+        TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, MulType,
+                             TTI::CastContextHint::None, CostKind, RedOp);
+
+    // Two inner extends (both operands of the mul), the mul, the outer extend.
+    CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
+    CostAfterReduction =
+        TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
+    return;
+  }
+
+  if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
+    // Matched reduce(ext(A))
+    bool IsUnsigned = isa<ZExtInst>(RedOp);
+    auto *ExtType = cast<VectorType>(RedOp->getOperand(0)->getType());
+
+    CostBeforeReduction =
+        TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
+                             TTI::CastContextHint::None, CostKind, RedOp);
+    CostAfterReduction =
+        TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
+                                     ExtType, FastMathFlags(), CostKind);
+    return;
+  }
+
+  // No foldable pattern feeds the reduction. Write both out-parameters so the
+  // result does not depend on how the caller initialised them.
+  CostBeforeReduction = 0;
+  CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
+                                                      std::nullopt, CostKind);
+}
+
+bool VectorCombine::foldBinopOfReductions(Instruction &I) {
+ Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
+ Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
+ if (BinOpOpc == Instruction::Sub)
+ ReductionIID = Intrinsic::vector_reduce_add;
+ if (ReductionIID == Intrinsic::not_intrinsic)
+ return false;
+
+ auto checkIntrinsicAndGetItsArgument = [](Value *V,
+ Intrinsic::ID IID) -> Value * {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(V);
+ if (!II)
+ return nullptr;
+ if (II->getIntrinsicID() == IID && II->hasOneUse())
+ return II->getArgOperand(0);
+ return nullptr;
+ };
+
+ Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
+ if (!V0)
+ return false;
+ Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
+ if (!V1)
+ return false;
+
+ VectorType *VTy = cast<VectorType>(V0->getType());
+ if (V1->getType() != VTy)
+ return false;
+ const IntrinsicInst &II0 = *cast<IntrinsicInst>(I.getOperand(0));
+ const IntrinsicInst &II1 = *cast<IntrinsicInst>(I.getOperand(1));
+ unsigned ReductionOpc =
+ getArithmeticReductionInstruction(II0.getIntrinsicID());
+
+ InstructionCost OldCost = 0;
+ InstructionCost NewCost = 0;
+ InstructionCost CostOfRedOperand0 = 0;
+ InstructionCost CostOfRed0 = 0;
+ InstructionCost CostOfRedOperand1 = 0;
+ InstructionCost CostOfRed1 = 0;
+ analyzeCostOfVecReduction(II0, CostKind, TTI, CostOfRedOperand0, CostOfRed0);
+ analyzeCostOfVecReduction(II1, CostKind, TTI, CostOfRedOperand1, CostOfRed1);
+ OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(&I, CostKind);
+ NewCost =
+ CostOfRedOperand0 + CostOfRedOperand1 +
+ TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) +
+ TTI.getArithmeticReductionCost(ReductionOpc, VTy, std::nullopt, CostKind);
+ // TODO: remove this
----------------
RKSimon wrote:
Please clean up the TODO and the debug output; we usually just do something like:
```
LLVM_DEBUG(dbgs() << "Found mergeable reductions: " << I
<< "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
<< "\n");
```
https://github.com/llvm/llvm-project/pull/121567
More information about the llvm-commits
mailing list