[llvm] [VPlan] Impl VPlan-based pattern match for ExtendedRed and MulAccRed (NFCI) (PR #113903)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 8 13:29:00 PST 2024


================
@@ -2185,30 +2186,150 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
   unsigned Opcode = RdxDesc.getOpcode();
 
-  // TODO: Support any-of and in-loop reductions.
+  // TODO: Support any-of reductions.
   assert(
       (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
        ForceTargetInstructionCost.getNumOccurrences() > 0) &&
       "Any-of reduction not implemented in VPlan-based cost model currently.");
-  assert(
-      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
-       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
-      "In-loop reduction not implemented in VPlan-based cost model currently.");
 
   assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
          "Inferred type and recurrence type mismatch.");
 
-  // Cost = Reduction cost + BinOp cost
-  InstructionCost Cost =
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost BaseCost =
       Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
   if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
     Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
-    return Cost + Ctx.TTI.getMinMaxReductionCost(
-                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+    BaseCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    BaseCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
   }
 
-  return Cost + Ctx.TTI.getArithmeticReductionCost(
-                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Default cost.
+  return BaseCost;
+}
+
+/// Cost of a reduction whose vector operand is extended (zext/sext).
+/// Returns the target's fused extended-reduction cost when it is valid and
+/// strictly cheaper; otherwise returns the cost of the standalone extend plus
+/// the plain reduction (scalar binop + vector reduction).
+InstructionCost
+VPExtendedReductionRecipe::computeCost(ElementCount VF,
+                                       VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = getResultType();
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+  // RdxDesc's opcode is the reduction binop (e.g. Add), not a cast opcode;
+  // the extend's own opcode is derived from IsZExt.
+  unsigned ExtOpcode = IsZExt ? Instruction::ZExt : Instruction::SExt;
+
+  // ReductionCost = scalar BinOp cost + vector reduction cost.
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    ReductionCost += Ctx.TTI.getMinMaxReductionCost(
+        Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  } else {
+    ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+        Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  // Cost of the standalone extend from getVecOp()'s vector type to the
+  // reduction's vector type (VectorTy).
+  auto *SrcTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp()), VF));
+  TTI::CastContextHint CCH = computeCCH(getVecOp()->getDefiningRecipe(), VF);
+  // Arm TTI will use the underlying instruction to determine the cost.
+  // NOTE(review): confirm getUnderlyingValue() is the extend and not the
+  // reduction itself, since it is passed as the cast's context instruction.
+  InstructionCost ExtendedCost = Ctx.TTI.getCastInstrCost(
+      ExtOpcode, VectorTy, SrcTy, CCH, CostKind,
+      dyn_cast_if_present<Instruction>(getUnderlyingValue()));
+
+  // Cost of the target's fused extend+reduce operation, if it provides one.
+  InstructionCost ExtendedRedCost = Ctx.TTI.getExtendedReductionCost(
+      Opcode, IsZExt, ElementTy, SrcTy, RdxDesc.getFastMathFlags(), CostKind);
+  // Fold the extend into the reduction only when that is valid and cheaper.
+  if (ExtendedRedCost.isValid() &&
+      ExtendedRedCost < ExtendedCost + ReductionCost)
+    return ExtendedRedCost;
+  return ExtendedCost + ReductionCost;
+}
+
+InstructionCost VPMulAccRecipe::computeCost(ElementCount VF,
+                                            VPCostContext &Ctx) const {
+  Type *ElementTy = IsExtended ? RdxDesc.getRecurrenceType()
+                               : Ctx.Types.inferScalarType(getVecOp0());
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  assert(Opcode == Instruction::Add &&
+         "Reduction opcode must be add in the VPMulAccRecipe.");
+
+  // BaseCost = Reduction cost + BinOp cost
+  InstructionCost ReductionCost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  ReductionCost += Ctx.TTI.getArithmeticReductionCost(
+      Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+
+  // Extended cost
+  InstructionCost ExtendedCost = 0;
+  if (IsExtended) {
+    auto *SrcTy = cast<VectorType>(
+        ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+    auto *DestTy = cast<VectorType>(ToVectorTy(getResultType(), VF));
+    TTI::CastContextHint CCH0 =
+        computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+    // Arm TTI will use the underlying instruction to determine the cost.
+    ExtendedCost = Ctx.TTI.getCastInstrCost(
+        ExtOp, DestTy, SrcTy, CCH0, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt0Instr()));
+    TTI::CastContextHint CCH1 =
+        computeCCH(getVecOp0()->getDefiningRecipe(), VF);
+    ExtendedCost += Ctx.TTI.getCastInstrCost(
+        ExtOp, DestTy, SrcTy, CCH1, TTI::TCK_RecipThroughput,
+        dyn_cast_if_present<Instruction>(getExt1Instr()));
+  }
+
+  // Mul cost
+  InstructionCost MulCost;
+  SmallVector<const Value *, 4> Operands;
+  Operands.append(MulInstr->value_op_begin(), MulInstr->value_op_end());
+  if (IsExtended)
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        Operands, MulInstr, &Ctx.TLI);
+  else {
+    VPValue *RHS = getVecOp1();
+    // Certain instructions can be cheaper to vectorize if they have a constant
+    // second vector operand. One example of this are shifts on x86.
+    TargetTransformInfo::OperandValueInfo RHSInfo = {
+        TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None};
+    if (RHS->isLiveIn())
+      RHSInfo = Ctx.TTI.getOperandInfo(RHS->getLiveInIRValue());
+
+    if (RHSInfo.Kind == TargetTransformInfo::OK_AnyValue &&
+        RHS->isDefinedOutsideLoopRegions())
+      RHSInfo.Kind = TargetTransformInfo::OK_UniformValue;
+    MulCost = Ctx.TTI.getArithmeticInstrCost(
+        Instruction::Mul, VectorTy, CostKind,
+        {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
+        RHSInfo, Operands, MulInstr, &Ctx.TLI);
+  }
+
+  // MulAccReduction Cost
+  VectorType *SrcVecTy =
+      cast<VectorType>(ToVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
+  InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
+      getExtOpcode() == Instruction::CastOps::ZExt, ElementTy, SrcVecTy,
+      CostKind);
+
+  // Check if folding ext into ExtendedReduction is profitable.
+  if (MulAccCost.isValid() &&
----------------
fhahn wrote:

Can we create the MulAcc recipe only if its cost is valid?

https://github.com/llvm/llvm-project/pull/113903


More information about the llvm-commits mailing list