[llvm] [VectorCombine] Fold binary op of reductions. (PR #121567)

David Green via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 20 08:21:21 PST 2025


================
@@ -1182,6 +1183,135 @@ bool VectorCombine::foldExtractedCmps(Instruction &I) {
   return true;
 }
 
+static void analyzeCostOfVecReduction(const IntrinsicInst &II,
+                                      TTI::TargetCostKind CostKind,
+                                      const TargetTransformInfo &TTI,
+                                      InstructionCost &CostBeforeReduction,
+                                      InstructionCost &CostAfterReduction) {
+  using namespace llvm::PatternMatch;
+  Instruction *Op0, *Op1;
+  Instruction *RedOp = dyn_cast<Instruction>(II.getOperand(0));
+  VectorType *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
+  unsigned ReductionOpc =
+      getArithmeticReductionInstruction(II.getIntrinsicID());
+  if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
+    bool IsUnsigned = isa<ZExtInst>(RedOp);
+    VectorType *ExtType =
+        cast<VectorType>(RedOp->getOperand(0)->getType());
+
+    CostBeforeReduction =
+        TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
+                             TTI::CastContextHint::None, CostKind, RedOp);
+    CostAfterReduction =
+        TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
+                                     ExtType, FastMathFlags(), CostKind);
+    return;
+  }
+  if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
+      match(RedOp,
+            m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) &&
+      match(Op0, m_ZExtOrSExt(m_Value())) &&
+      Op0->getOpcode() == Op1->getOpcode() &&
+      Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
+      (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
+    // Matched reduce.add(ext(mul(ext(A), ext(B)))
+    bool IsUnsigned = isa<ZExtInst>(Op0);
+    VectorType *ExtType =
+        cast<VectorType>(Op0->getOperand(0)->getType());
+    VectorType *MulType = VectorType::get(Op0->getType(), VecRedTy);
----------------
davemgreen wrote:

I think it was a note about a previous version that no longer applies.

https://github.com/llvm/llvm-project/pull/121567


More information about the llvm-commits mailing list