[llvm] 1d9b322 - [VPlan] Implement VPWidenSelectRecipe::computeCost.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 21 19:10:57 PDT 2024


Author: Florian Hahn
Date: 2024-10-22T03:10:04+01:00
New Revision: 1d9b3222f3de7bad4ef27b7e4d7798f840097380

URL: https://github.com/llvm/llvm-project/commit/1d9b3222f3de7bad4ef27b7e4d7798f840097380
DIFF: https://github.com/llvm/llvm-project/commit/1d9b3222f3de7bad4ef27b7e4d7798f840097380.diff

LOG: [VPlan] Implement VPWidenSelectRecipe::computeCost.

Implement VPlan-based cost computation for VPWidenSelectRecipe.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/VPlan.cpp
    llvm/lib/Transforms/Vectorize/VPlan.h
    llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
    llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index c1b97791331bcf..6ab8fb45c351b4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1701,3 +1701,12 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
       Plan->print(O);
 }
 #endif
+
+/// Returns TTI's operand info for \p V. Only a live-in wraps an IR value that
+/// TTI can analyze; any other VPValue yields default-constructed info.
+TargetTransformInfo::OperandValueInfo
+VPCostContext::getOperandInfo(VPValue *V) const {
+  if (V->isLiveIn())
+    return TTI::getOperandInfo(V->getLiveInIRValue());
+  return {};
+}

diff  --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 59a084401cc9bf..8dff800a0b2224 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -38,6 +38,7 @@
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/IR/FMF.h"
@@ -738,6 +739,9 @@ struct VPCostContext {
   /// Return true if the cost for \p UI shouldn't be computed, e.g. because it
   /// has already been pre-computed.
   bool skipCostComputation(Instruction *UI, bool IsVector) const;
+
+  /// Returns the OperandInfo for \p V, if it is a live-in.
+  TargetTransformInfo::OperandValueInfo getOperandInfo(VPValue *V) const;
 };
 
 /// VPRecipeBase is a base class modeling a sequence of one or more output IR
@@ -1844,6 +1848,10 @@ struct VPWidenSelectRecipe : public VPSingleDefRecipe {
   /// Produce a widened version of the select instruction.
   void execute(VPTransformState &State) override;
 
+  /// Return the cost of this VPWidenSelectRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 1b05afd6b117a5..b8d69f4cd394d0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -75,8 +75,8 @@ template <unsigned BitWidth = 0> struct specific_intval {
     if (!CI)
       return false;
 
-    assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) &&
-           "Trying the match constant with unexpected bitwidth.");
+    if (BitWidth != 0 && CI->getBitWidth() != BitWidth)
+      return false;
     return APInt::isSameValue(CI->getValue(), Val);
   }
 };
@@ -87,6 +87,11 @@ inline specific_intval<0> m_SpecificInt(uint64_t V) {
 
 inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }
 
+/// Match a 1-bit integer constant equal to 1, i.e. i1 true.
+inline specific_intval<1> m_True() {
+  return specific_intval<1>(APInt(64, 1));
+}
+
 /// Matching combinators
 template <typename LTy, typename RTy> struct match_combine_or {
   LTy L;
@@ -122,7 +124,8 @@ struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
     auto *DefR = dyn_cast<RecipeTy>(R);
     // Check for recipes that do not have opcodes.
     if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value ||
-                  std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value)
+                  std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value ||
+                  std::is_same<RecipeTy, VPWidenSelectRecipe>::value)
       return DefR;
     else
       return DefR && DefR->getOpcode() == Opcode;
@@ -322,10 +325,34 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
   return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
 }
 
+template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode>
+using AllTernaryRecipe_match = // 3-operand recipe of a kind listed below.
+    Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, Opcode, false,
+                 VPReplicateRecipe, VPInstruction, VPWidenSelectRecipe>;
+
+template <typename Op0_t, typename Op1_t, typename Op2_t>
+inline AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>
+m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
+  return AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>(
+      {Op0, Op1, Op2}); // Matches: select Op0, Op1, Op2.
+}
+
 template <typename Op0_t, typename Op1_t>
-inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>
+inline match_combine_or<
+    BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>,
+    AllTernaryRecipe_match<Op0_t, Op1_t, specific_intval<1>,
+                           Instruction::Select>>
 m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
-  return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1);
+  return m_CombineOr( // VPInstruction::LogicalAnd or select Op0, Op1, false.
+      m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1),
+      m_Select(Op0, Op1, m_False()));
+}
+
+template <typename Op0_t, typename Op1_t>
+inline AllTernaryRecipe_match<Op0_t, specific_intval<1>, Op1_t,
+                              Instruction::Select>
+m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) {
+  return m_Select(Op0, m_True(), Op1); // Only the form select Op0, true, Op1.
 }
 
 using VPCanonicalIVPHI_match =
@@ -344,7 +371,6 @@ inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,
                                                            const Op1_t &Op1) {
   return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1);
 }
-
 } // namespace VPlanPatternMatch
 } // namespace llvm
 

diff  --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 41f13cc2d9a978..945874fd2c1ebb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -13,6 +13,7 @@
 
 #include "VPlan.h"
 #include "VPlanAnalysis.h"
+#include "VPlanPatternMatch.h"
 #include "VPlanUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
@@ -23,6 +24,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/IR/VectorBuilder.h"
@@ -1200,6 +1202,61 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) {
   State.addMetadata(Sel, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
 }
 
+/// Return the cost of this recipe for \p VF. If the condition is not defined
+/// outside the loop regions and the select is an i1 logical and/or
+/// (select x, y, false / select x, true, y), it is costed as the equivalent
+/// vector And/Or; otherwise it is costed as a select.
+InstructionCost VPWidenSelectRecipe::computeCost(ElementCount VF,
+                                                 VPCostContext &Ctx) const {
+  SelectInst *SI = cast<SelectInst>(getUnderlyingValue());
+  // A condition defined outside the loop regions is costed as a scalar
+  // condition; otherwise it is widened to VF lanes (see CondTy below).
+  bool ScalarCond = getOperand(0)->isDefinedOutsideLoopRegions();
+  Type *ScalarTy = Ctx.Types.inferScalarType(this);
+  Type *VectorTy = ToVectorTy(ScalarTy, VF);
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+  using namespace llvm::VPlanPatternMatch;
+  if (!ScalarCond && ScalarTy->getScalarSizeInBits() == 1) {
+    // select x, y, false --> x & y
+    // select x, true, y --> x | y
+    // Pick the opcode first and bind Op0/Op1 with that same pattern, so the
+    // operand info matches the opcode even for select x, true, false, which
+    // matches both forms.
+    VPValue *Op0, *Op1;
+    bool IsLogicalOr = match(this, m_LogicalOr(m_VPValue(Op0), m_VPValue(Op1)));
+    if (IsLogicalOr ||
+        match(this, m_LogicalAnd(m_VPValue(Op0), m_VPValue(Op1)))) {
+      const auto [Op1VK, Op1VP] = Ctx.getOperandInfo(Op0);
+      const auto [Op2VK, Op2VP] = Ctx.getOperandInfo(Op1);
+
+      // Pass TTI the IR values of the two And/Or operands rather than the
+      // select's three operands, so the argument list lines up with the
+      // binary opcode being costed. Omit them if either has no IR value.
+      SmallVector<const Value *, 2> Operands;
+      const Value *LHS = Op0->getUnderlyingValue();
+      const Value *RHS = Op1->getUnderlyingValue();
+      if (LHS && RHS)
+        Operands = {LHS, RHS};
+      return Ctx.TTI.getArithmeticInstrCost(
+          IsLogicalOr ? Instruction::Or : Instruction::And, VectorTy, CostKind,
+          {Op1VK, Op1VP}, {Op2VK, Op2VP}, Operands, SI);
+    }
+  }
+
+  Type *CondTy = Ctx.Types.inferScalarType(getOperand(0));
+  if (!ScalarCond)
+    CondTy = VectorType::get(CondTy, VF);
+
+  // Forward the predicate of the compare feeding the condition, if any.
+  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+  if (auto *Cmp = dyn_cast<CmpInst>(SI->getCondition()))
+    Pred = Cmp->getPredicate();
+  return Ctx.TTI.getCmpSelInstrCost(Instruction::Select, VectorTy, CondTy, Pred,
+                                    CostKind, {TTI::OK_AnyValue, TTI::OP_None},
+                                    {TTI::OK_AnyValue, TTI::OP_None}, SI);
+}
+
 VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
     const FastMathFlags &FMF) {
   AllowReassoc = FMF.allowReassoc();


        


More information about the llvm-commits mailing list