[llvm] [LV] Add support for partial reductions without a binary op (PR #133922)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 2 01:13:29 PDT 2025


================
@@ -296,49 +296,70 @@ bool VPRecipeBase::isScalarCast() const {
 InstructionCost
 VPPartialReductionRecipe::computeCost(ElementCount VF,
                                       VPCostContext &Ctx) const {
-  std::optional<unsigned> Opcode = std::nullopt;
-  VPValue *BinOp = getOperand(1);
+  std::optional<unsigned> Opcode;
+  VPValue *Op = getOperand(0);
+  VPRecipeBase *OpR = Op->getDefiningRecipe();
 
-  // If the partial reduction is predicated, a select will be operand 0 rather
-  // than the binary op
+  // If the partial reduction is predicated, a select will be operand 0
   using namespace llvm::VPlanPatternMatch;
-  if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
-    BinOp = BinOp->getDefiningRecipe()->getOperand(1);
-
-  // If BinOp is a negation, use the side effect of match to assign the actual
-  // binary operation to BinOp
-  match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
-  VPRecipeBase *BinOpR = BinOp->getDefiningRecipe();
-
-  if (auto *WidenR = dyn_cast<VPWidenRecipe>(BinOpR))
-    Opcode = std::make_optional(WidenR->getOpcode());
-
-  VPRecipeBase *ExtAR = BinOpR->getOperand(0)->getDefiningRecipe();
-  VPRecipeBase *ExtBR = BinOpR->getOperand(1)->getDefiningRecipe();
+  if (match(getOperand(1), m_Select(m_VPValue(), m_VPValue(Op), m_VPValue()))) {
+    OpR = Op->getDefiningRecipe();
+  }
 
-  auto *PhiType = Ctx.Types.inferScalarType(getOperand(1));
-  auto *InputTypeA = Ctx.Types.inferScalarType(ExtAR ? ExtAR->getOperand(0)
-                                                     : BinOpR->getOperand(0));
-  auto *InputTypeB = Ctx.Types.inferScalarType(ExtBR ? ExtBR->getOperand(0)
-                                                     : BinOpR->getOperand(1));
+  Type *InputTypeA = nullptr, *InputTypeB = nullptr;
+  TTI::PartialReductionExtendKind ExtAType = TTI::PR_None,
+                                  ExtBType = TTI::PR_None;
 
   auto GetExtendKind = [](VPRecipeBase *R) {
-    // The extend could come from outside the plan.
     if (!R)
-      return TargetTransformInfo::PR_None;
+      return TTI::PR_None;
     auto *WidenCastR = dyn_cast<VPWidenCastRecipe>(R);
     if (!WidenCastR)
-      return TargetTransformInfo::PR_None;
+      return TTI::PR_None;
     if (WidenCastR->getOpcode() == Instruction::CastOps::ZExt)
-      return TargetTransformInfo::PR_ZeroExtend;
+      return TTI::PR_ZeroExtend;
     if (WidenCastR->getOpcode() == Instruction::CastOps::SExt)
-      return TargetTransformInfo::PR_SignExtend;
-    return TargetTransformInfo::PR_None;
+      return TTI::PR_SignExtend;
+    return TTI::PR_None;
   };
 
-  return Ctx.TTI.getPartialReductionCost(
-      getOpcode(), InputTypeA, InputTypeB, PhiType, VF, GetExtendKind(ExtAR),
-      GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
+  // Pick out opcode, type/ext information and use sub side effects from a widen
+  // recipe.
+  auto HandleWiden = [&](VPWidenRecipe *Widen) {
+    if (match(Widen,
----------------
fhahn wrote:

On top of https://github.com/llvm/llvm-project/pull/146073, would we still need to detect the various patterns here, or could this be replaced by different ExpressionTypes?

https://github.com/llvm/llvm-project/pull/133922


More information about the llvm-commits mailing list