[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 9 02:05:49 PST 2024


================
@@ -2192,30 +2193,146 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");
 
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
-      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  // Note that TTI should model the cost of moving result to the scalar register
+  // and the binOp cost in the getReductionCost().
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    return Ctx.TTI.getMinMaxReductionCost(Id, VectorTy,
+                                          RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  return Ctx.TTI.getArithmeticReductionCost(
+      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+}
+
+/// Cost of an extend (zext/sext of the vector operand) feeding a reduction.
+/// Returns the target's fused extended-reduction cost when that is strictly
+/// cheaper than a separate vector cast plus reduction; otherwise returns the
+/// sum of the two separate costs.
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = getResultType();
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  auto *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+  // The folded extend is a zext or sext; this cast opcode is distinct from the
+  // reduction opcode above and is what the standalone cast cost must use.
+  unsigned ExtOpcode = isZExt() ? Instruction::ZExt : Instruction::SExt;
+
+  // Cost of the fused extend + reduction as modelled by the target.
+  InstructionCost ExtendedRedCost =
+      Ctx.TTI.getExtendedReductionCost(Opcode, isZExt(), ElementTy, SrcVecTy,
+                                       RdxDesc.getFastMathFlags(), CostKind);
+
+  assert(ExtendedRedCost.isValid() && "VPExtendedReductionRecipe should not be "
+                                      "created if the cost is invalid.");
+
+  // Cost of the plain reduction over the already-extended vector.
+  InstructionCost ReductionCost;
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    ReductionCost = Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    ReductionCost = Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // Cost of the standalone vector cast from SrcVecTy to VectorTy.
+  TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
+  // NOTE(review): some targets (e.g. Arm) inspect the context instruction.
+  // This passes the recipe's underlying value, which is presumably the
+  // reduction rather than the extend — confirm this is the intended one.
+  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
+      ExtOpcode, VectorTy, SrcVecTy, CCH, CostKind,
+      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+
+  // Fold the extend into the reduction only if that is strictly cheaper than
+  // performing the cast and the reduction separately.
+  if (ExtendedRedCost.isValid() &&
+      ExtendedRedCost < ExtendedCost + ReductionCost)
+    return ExtendedRedCost;
+
+  return ExtendedCost + ReductionCost;
+}
+
+InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor();
+  Type *ElementTy = isExtended() ? RdxDesc.getRecurrenceType()
+                                 : Ctx.Types.inferScalarType(getVecOp0());
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(Opcode == Instruction::Add &&
+         "Reduction opcode must be add in the VPMulAccRecipe.");
+  // MulAccReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+  InstructionCost MulAccCost =
+      Ctx.TTI.getMulAccReductionCost(isZExt(), ElementTy, SrcVecTy, CostKind);
+
+  assert(MulAccCost.isValid() && "VPMulAccRecipe should not be "
+                                 "created if the cost is invalid.");
+
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost ReductionCost = Ctx.TTI.getArithmeticReductionCost(
+      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+
+  // Extended cost
----------------
fhahn wrote:

Redundant comment

https://github.com/llvm/llvm-project/pull/113903


More information about the llvm-commits mailing list