[llvm] 1d9b322 - [VPlan] Implement VPWidenSelectRecipe::computeCost.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 21 19:10:57 PDT 2024
Author: Florian Hahn
Date: 2024-10-22T03:10:04+01:00
New Revision: 1d9b3222f3de7bad4ef27b7e4d7798f840097380
URL: https://github.com/llvm/llvm-project/commit/1d9b3222f3de7bad4ef27b7e4d7798f840097380
DIFF: https://github.com/llvm/llvm-project/commit/1d9b3222f3de7bad4ef27b7e4d7798f840097380.diff
LOG: [VPlan] Implement VPWidenSelectRecipe::computeCost.
Implement VPlan-based cost computation for VPWidenSelectRecipe.
Added:
Modified:
llvm/lib/Transforms/Vectorize/VPlan.cpp
llvm/lib/Transforms/Vectorize/VPlan.h
llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
index c1b97791331bcf..6ab8fb45c351b4 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -1701,3 +1701,11 @@ void LoopVectorizationPlanner::printPlans(raw_ostream &O) {
Plan->print(O);
}
#endif
+
+TargetTransformInfo::OperandValueInfo
+VPCostContext::getOperandInfo(VPValue *V) const {
+  // Only a live-in wraps an IR value that TTI can inspect; every other
+  // VPValue is described by a default-constructed OperandValueInfo.
+  return V->isLiveIn() ? TTI::getOperandInfo(V->getLiveInIRValue())
+                       : TargetTransformInfo::OperandValueInfo();
+}
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 59a084401cc9bf..8dff800a0b2224 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -38,6 +38,7 @@
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/FMF.h"
@@ -738,6 +739,9 @@ struct VPCostContext {
/// Return true if the cost for \p UI shouldn't be computed, e.g. because it
/// has already been pre-computed.
bool skipCostComputation(Instruction *UI, bool IsVector) const;
+
+ /// Returns the OperandInfo for \p V, if it is a live-in.
+ TargetTransformInfo::OperandValueInfo getOperandInfo(VPValue *V) const;
};
/// VPRecipeBase is a base class modeling a sequence of one or more output IR
@@ -1844,6 +1848,10 @@ struct VPWidenSelectRecipe : public VPSingleDefRecipe {
/// Produce a widened version of the select instruction.
void execute(VPTransformState &State) override;
+ /// Return the cost of this VPWidenSelectRecipe.
+ InstructionCost computeCost(ElementCount VF,
+ VPCostContext &Ctx) const override;
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
/// Print the recipe.
void print(raw_ostream &O, const Twine &Indent,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index 1b05afd6b117a5..b8d69f4cd394d0 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -75,8 +75,8 @@ template <unsigned BitWidth = 0> struct specific_intval {
if (!CI)
return false;
- assert((BitWidth == 0 || CI->getBitWidth() == BitWidth) &&
- "Trying the match constant with unexpected bitwidth.");
+ if (BitWidth != 0 && CI->getBitWidth() != BitWidth)
+ return false;
return APInt::isSameValue(CI->getValue(), Val);
}
};
@@ -87,6 +87,8 @@ inline specific_intval<0> m_SpecificInt(uint64_t V) {
inline specific_intval<1> m_False() { return specific_intval<1>(APInt(64, 0)); }
+/// Match an integer constant with value 1 and bit width 1 (i.e. i1 true);
+/// constants of any other bit width are rejected rather than asserted on.
+inline specific_intval<1> m_True() { return specific_intval<1>(APInt(64, 1)); }
+
/// Matching combinators
template <typename LTy, typename RTy> struct match_combine_or {
LTy L;
@@ -122,7 +124,8 @@ struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
auto *DefR = dyn_cast<RecipeTy>(R);
// Check for recipes that do not have opcodes.
if constexpr (std::is_same<RecipeTy, VPScalarIVStepsRecipe>::value ||
- std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value)
+ std::is_same<RecipeTy, VPCanonicalIVPHIRecipe>::value ||
+ std::is_same<RecipeTy, VPWidenSelectRecipe>::value)
return DefR;
else
return DefR && DefR->getOpcode() == Opcode;
@@ -322,10 +325,34 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
}
+/// Matcher for a three-operand recipe with opcode \p Opcode whose operands
+/// match Op0_t, Op1_t and Op2_t in order. Accepted recipe kinds are
+/// VPReplicateRecipe, VPInstruction and VPWidenSelectRecipe (the latter is
+/// matched without an opcode check, see MatchRecipeAndOpcode). Operands are
+/// not commuted (the `false` argument presumably is Recipe_match's
+/// Commutative flag, cf. m_c_BinaryOr -- confirm).
+template <typename Op0_t, typename Op1_t, typename Op2_t, unsigned Opcode>
+using AllTernaryRecipe_match =
+ Recipe_match<std::tuple<Op0_t, Op1_t, Op2_t>, Opcode, false,
+ VPReplicateRecipe, VPInstruction, VPWidenSelectRecipe>;
+
+/// Match an Instruction::Select whose operands (condition, true value, false
+/// value) match \p Op0, \p Op1 and \p Op2 respectively.
+template <typename Op0_t, typename Op1_t, typename Op2_t>
+inline AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>
+m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
+ return AllTernaryRecipe_match<Op0_t, Op1_t, Op2_t, Instruction::Select>(
+ {Op0, Op1, Op2});
+}
+
+/// Match a logical and of \p Op0 and \p Op1: either a
+/// VPInstruction::LogicalAnd, or a select of the form `select Op0, Op1, false`.
template <typename Op0_t, typename Op1_t>
-inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>
+inline match_combine_or<
+ BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::LogicalAnd>,
+ AllTernaryRecipe_match<Op0_t, Op1_t, specific_intval<1>,
+ Instruction::Select>>
m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) {
- return m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1);
+ return m_CombineOr(
+ m_VPInstruction<VPInstruction::LogicalAnd, Op0_t, Op1_t>(Op0, Op1),
+ m_Select(Op0, Op1, m_False()));
+}
+
+/// Match a logical or of \p Op0 and \p Op1, expressed as a select of the form
+/// `select Op0, true, Op1`.
+template <typename Op0_t, typename Op1_t>
+inline AllTernaryRecipe_match<Op0_t, specific_intval<1>, Op1_t,
+ Instruction::Select>
+m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) {
+ return m_Select(Op0, m_True(), Op1);
}
using VPCanonicalIVPHI_match =
@@ -344,7 +371,6 @@ inline VPScalarIVSteps_match<Op0_t, Op1_t> m_ScalarIVSteps(const Op0_t &Op0,
const Op1_t &Op1) {
return VPScalarIVSteps_match<Op0_t, Op1_t>(Op0, Op1);
}
-
} // namespace VPlanPatternMatch
} // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 41f13cc2d9a978..945874fd2c1ebb 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -13,6 +13,7 @@
#include "VPlan.h"
#include "VPlanAnalysis.h"
+#include "VPlanPatternMatch.h"
#include "VPlanUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -23,6 +24,7 @@
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/VectorBuilder.h"
@@ -1200,6 +1202,46 @@ void VPWidenSelectRecipe::execute(VPTransformState &State) {
State.addMetadata(Sel, dyn_cast_or_null<Instruction>(getUnderlyingValue()));
}
+InstructionCost VPWidenSelectRecipe::computeCost(ElementCount VF,
+                                                 VPCostContext &Ctx) const {
+  // Vector selects of i1 values in `and`/`or` form are costed as the
+  // corresponding bitwise operation; every other select is costed as a
+  // select. NOTE(review): this presumably has to agree with the legacy
+  // LoopVectorizationCostModel select costing -- keep the two in sync.
+  SelectInst *SI = cast<SelectInst>(getUnderlyingValue());
+  bool ScalarCond = getOperand(0)->isDefinedOutsideLoopRegions();
+  Type *ScalarTy = Ctx.Types.inferScalarType(this);
+  // Reuse the scalar type inferred above rather than querying it again.
+  Type *VectorTy = ToVectorTy(ScalarTy, VF);
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+
+  VPValue *Op0, *Op1;
+  using namespace llvm::VPlanPatternMatch;
+  if (!ScalarCond && ScalarTy->getScalarSizeInBits() == 1 &&
+      (match(this, m_LogicalAnd(m_VPValue(Op0), m_VPValue(Op1))) ||
+       match(this, m_LogicalOr(m_VPValue(Op0), m_VPValue(Op1))))) {
+    // select x, y, false --> x & y
+    // select x, true, y --> x | y
+    const auto [Op1VK, Op1VP] = Ctx.getOperandInfo(Op0);
+    const auto [Op2VK, Op2VP] = Ctx.getOperandInfo(Op1);
+
+    // Hand the IR operands to TTI only if every recipe operand still has an
+    // underlying IR value.
+    SmallVector<const Value *, 2> Operands;
+    if (all_of(operands(),
+               [](VPValue *Op) { return Op->getUnderlyingValue(); }))
+      Operands.append(SI->op_begin(), SI->op_end());
+
+    // Re-match the `or` form on its own: `select x, true, false` matches both
+    // forms and is costed as `or`, with the operand info gathered from the
+    // `and` binding above. Nothing needs to be bound here.
+    bool IsLogicalOr = match(this, m_LogicalOr(m_VPValue(), m_VPValue()));
+    return Ctx.TTI.getArithmeticInstrCost(
+        IsLogicalOr ? Instruction::Or : Instruction::And, VectorTy, CostKind,
+        {Op1VK, Op1VP}, {Op2VK, Op2VP}, Operands, SI);
+  }
+
+  // A loop-invariant condition stays scalar; otherwise it is widened to VF.
+  Type *CondTy = Ctx.Types.inferScalarType(getOperand(0));
+  if (!ScalarCond)
+    CondTy = VectorType::get(CondTy, VF);
+
+  // Forward the predicate of a compare feeding the (IR) condition, if any.
+  CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
+  if (auto *Cmp = dyn_cast<CmpInst>(SI->getCondition()))
+    Pred = Cmp->getPredicate();
+  return Ctx.TTI.getCmpSelInstrCost(Instruction::Select, VectorTy, CondTy, Pred,
+                                    CostKind, {TTI::OK_AnyValue, TTI::OP_None},
+                                    {TTI::OK_AnyValue, TTI::OP_None}, SI);
+}
+
VPRecipeWithIRFlags::FastMathFlagsTy::FastMathFlagsTy(
const FastMathFlags &FMF) {
AllowReassoc = FMF.allowReassoc();
More information about the llvm-commits mailing list