[llvm] [VectorCombine] Fold binary op of reductions. (PR #121567)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 20 09:39:50 PST 2025
================
@@ -1242,6 +1243,121 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
return true;
}
+/// Compute the cost of the vector reduction intrinsic \p II, taking into
+/// account whether the instruction feeding it can be folded into the
+/// reduction itself.
+///
+/// \p CostBeforeReduction receives the cost of the feeding instructions that
+/// a combined reduction would absorb. \p CostAfterReduction receives the cost
+/// of the (extended, multiply-accumulate, or plain) reduction. Recognized
+/// shapes:
+///  - reduce(zext/sext(A))                     -> extended reduction cost
+///  - reduce.add(ext(mul(ext(A), ext(B))))     -> mul-acc reduction cost
+///  - anything else                            -> plain arithmetic reduction
+///    cost; \p CostBeforeReduction is NOT written on this path, so callers
+///    must initialize it themselves.
+static void analyzeCostOfVecReduction(const IntrinsicInst &II,
+                                      TTI::TargetCostKind CostKind,
+                                      const TargetTransformInfo &TTI,
+                                      InstructionCost &CostBeforeReduction,
+                                      InstructionCost &CostAfterReduction) {
+  Instruction *Op0, *Op1;
+  auto *RedOp = dyn_cast<Instruction>(II.getOperand(0));
+  auto *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
+  unsigned ReductionOpc =
+      getArithmeticReductionInstruction(II.getIntrinsicID());
+  if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
+    // Matched reduce(ext(A)).
+    bool IsUnsigned = isa<ZExtInst>(RedOp);
+    auto *ExtType = cast<VectorType>(RedOp->getOperand(0)->getType());
+
+    CostBeforeReduction =
+        TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
+                             TTI::CastContextHint::None, CostKind, RedOp);
+    CostAfterReduction =
+        TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
+                                     ExtType, FastMathFlags(), CostKind);
+    return;
+  }
+  if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
+      match(RedOp,
+            m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) &&
+      match(Op0, m_ZExtOrSExt(m_Value())) &&
+      Op0->getOpcode() == Op1->getOpcode() &&
+      Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
+      (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
+    // Matched reduce.add(ext(mul(ext(A), ext(B)))
+    bool IsUnsigned = isa<ZExtInst>(Op0);
+    auto *ExtType = cast<VectorType>(Op0->getOperand(0)->getType());
+    // Op0 is an operand of the vector mul, so its type already IS the
+    // multiply's vector type. Feeding it to VectorType::get() as an element
+    // type would request an invalid vector-of-vectors type.
+    auto *MulType = cast<VectorType>(Op0->getType());
+
+    InstructionCost ExtCost =
+        TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
+                             TTI::CastContextHint::None, CostKind, Op0);
+    InstructionCost MulCost =
+        TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
+    InstructionCost Ext2Cost =
+        TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, MulType,
+                             TTI::CastContextHint::None, CostKind, RedOp);
+
+    // NOTE(review): the inner extend is counted twice even when Op0 == Op1,
+    // where mul(ext(A), ext(A)) shares a single extend instruction.
+    CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
+    CostAfterReduction =
+        TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
+    return;
+  }
+  // No foldable feeding pattern: cost a plain arithmetic reduction.
+  // CostBeforeReduction is not written here.
+  CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
+                                                      std::nullopt, CostKind);
+}
+
+bool VectorCombine::foldBinopOfReductions(Instruction &I) {
+ Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
+ Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
+ if (BinOpOpc == Instruction::Sub)
+ ReductionIID = Intrinsic::vector_reduce_add;
+ if (ReductionIID == Intrinsic::not_intrinsic)
+ return false;
+
+ auto checkIntrinsicAndGetItsArgument = [](Value *V,
+ Intrinsic::ID IID) -> Value * {
+ auto *II = dyn_cast<IntrinsicInst>(V);
+ if (!II)
+ return nullptr;
+ if (II->getIntrinsicID() == IID && II->hasOneUse())
+ return II->getArgOperand(0);
+ return nullptr;
+ };
+
+ Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
+ if (!V0)
+ return false;
+ Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
+ if (!V1)
+ return false;
+
+ auto *VTy = cast<VectorType>(V0->getType());
+ if (V1->getType() != VTy)
+ return false;
+ const auto &II0 = *cast<IntrinsicInst>(I.getOperand(0));
----------------
topperc wrote:
Can `checkIntrinsicAndGetItsArgument` return `IntrinsicInst*` instead of `Value*`?
https://github.com/llvm/llvm-project/pull/121567
More information about the llvm-commits mailing list