[llvm] [VPlan] First step towards VPlan cost modeling (LegacyCM in CostCtx) (PR #92555)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Thu May 23 14:30:15 PDT 2024


================
@@ -730,6 +731,81 @@ void VPRegionBlock::execute(VPTransformState *State) {
   State->Instance.reset();
 }
 
+static InstructionCost computeCostForRecipe(VPRecipeBase *R, ElementCount VF,
+                                            VPCostContext &Ctx) {
+  Instruction *UI = nullptr;
+  if (auto *S = dyn_cast<VPSingleDefRecipe>(R))
+    UI = dyn_cast_or_null<Instruction>(S->getUnderlyingValue());
+  if (UI && Ctx.skipForCostComputation(UI))
+    return 0;
+
+  InstructionCost RecipeCost = R->computeCost(VF, Ctx);
+  if (ForceTargetInstructionCost.getNumOccurrences() > 0 &&
+      RecipeCost.isValid())
+    RecipeCost = InstructionCost(ForceTargetInstructionCost);
+
+  LLVM_DEBUG({
+    dbgs() << "Cost of " << RecipeCost << " for VF " << VF << ": ";
+    R->dump();
+  });
+  return RecipeCost;
+}
+
+InstructionCost VPBasicBlock::computeCost(ElementCount VF, VPCostContext &Ctx) {
+  InstructionCost Cost = 0;
+  for (VPRecipeBase &R : *this)
+    Cost += computeCostForRecipe(&R, VF, Ctx);
+  return Cost;
+}
+
+InstructionCost VPRegionBlock::computeCost(ElementCount VF,
+                                           VPCostContext &Ctx) {
+  InstructionCost Cost = 0;
+  if (!isReplicator()) {
+    for (VPBlockBase *Block : vp_depth_first_shallow(getEntry()))
+      Cost += Block->computeCost(VF, Ctx);
+    return Cost;
+  }
+
+  using namespace llvm::VPlanPatternMatch;
+  assert(isReplicator() && "can only compute cost for a replicator region");
+  VPBasicBlock *Then = cast<VPBasicBlock>(getEntry()->getSuccessors()[0]);
+  for (VPRecipeBase &R : *Then)
+    Cost += computeCostForRecipe(&R, VF, Ctx);
+
+  // Note the cost estimates below closely match the current legacy cost model.
+  auto *BOM = cast<VPBranchOnMaskRecipe>(&getEntryBasicBlock()->front());
+  VPValue *Cond = BOM->getOperand(0);
+
+  // Check if Cond is a uniform compare or a header mask.
+  VPValue *Op;
+  bool IsHeaderMaskOrUniformCond =
+      vputils::isUniformCompare(Cond) ||
+      match(Cond, m_ActiveLaneMask(m_VPValue(), m_VPValue())) ||
+      (match(Cond, m_Binary<Instruction::ICmp>(m_VPValue(), m_VPValue(Op))) &&
+       Op == getPlan()->getOrCreateBackedgeTakenCount()) ||
+      isa<VPActiveLaneMaskPHIRecipe>(Cond);
----------------
fhahn wrote:

Moved, thanks!

https://github.com/llvm/llvm-project/pull/92555


More information about the llvm-commits mailing list