[llvm] fd93a5e - [VPlan] Support match unary and binary recipes in pattern matcher (NFC).
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 18 07:25:56 PDT 2024
Author: Florian Hahn
Date: 2024-03-18T14:24:52Z
New Revision: fd93a5e3c06a90e931c645948aa73ee9894699d7
URL: https://github.com/llvm/llvm-project/commit/fd93a5e3c06a90e931c645948aa73ee9894699d7
DIFF: https://github.com/llvm/llvm-project/commit/fd93a5e3c06a90e931c645948aa73ee9894699d7.diff
LOG: [VPlan] Support match unary and binary recipes in pattern matcher (NFC).
Generalize pattern matchers to take recipe types to match as template
arguments and use it to provide matchers for unary and binary recipes
with specific opcodes and a list of recipe types (VPWidenRecipe,
VPReplicateRecipe, VPWidenCastRecipe, VPInstruction)
The new matchers are used to simplify and generalize the code in
simplifyRecipes.
Added:
Modified:
llvm/lib/Transforms/Vectorize/VPlan.h
llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index af6d0081bffebc..d77c7554d50e4f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2136,6 +2136,8 @@ class VPReplicateRecipe : public VPRecipeWithIRFlags {
assert(isPredicated() && "Trying to get the mask of a unpredicated recipe");
return getOperand(getNumOperands() - 1);
}
+
+ unsigned getOpcode() const { return getUnderlyingInstr()->getOpcode(); }
};
/// A recipe for generating conditional branches on the bits of a mask.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
index b90c588b607564..aa253590694514 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -50,13 +50,82 @@ template <typename Class> struct bind_ty {
}
};
+/// Match a specified integer value or vector of all elements of that
+/// value.
+struct specific_intval {
+ APInt Val;
+
+ specific_intval(APInt V) : Val(std::move(V)) {}
+
+ bool match(VPValue *VPV) {
+ if (!VPV->isLiveIn())
+ return false;
+ Value *V = VPV->getLiveInIRValue();
+ const auto *CI = dyn_cast<ConstantInt>(V);
+ if (!CI && V->getType()->isVectorTy())
+ if (const auto *C = dyn_cast<Constant>(V))
+ CI = dyn_cast_or_null<ConstantInt>(
+ C->getSplatValue(/*UndefsAllowed=*/false));
+
+ return CI && APInt::isSameValue(CI->getValue(), Val);
+ }
+};
+
+inline specific_intval m_SpecificInt(uint64_t V) {
+ return specific_intval(APInt(64, V));
+}
+
+/// Matching combinators
+template <typename LTy, typename RTy> struct match_combine_or {
+ LTy L;
+ RTy R;
+
+ match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {}
+
+ template <typename ITy> bool match(ITy *V) {
+ if (L.match(V))
+ return true;
+ if (R.match(V))
+ return true;
+ return false;
+ }
+};
+
+template <typename LTy, typename RTy>
+inline match_combine_or<LTy, RTy> m_CombineOr(const LTy &L, const RTy &R) {
+ return match_combine_or<LTy, RTy>(L, R);
+}
+
/// Match a VPValue, capturing it if we match.
inline bind_ty<VPValue> m_VPValue(VPValue *&V) { return V; }
-template <typename Op0_t, unsigned Opcode> struct UnaryVPInstruction_match {
+namespace detail {
+
+/// A helper to match an opcode against multiple recipe types.
+template <unsigned Opcode, typename...> struct MatchRecipeAndOpcode {};
+
+template <unsigned Opcode, typename RecipeTy>
+struct MatchRecipeAndOpcode<Opcode, RecipeTy> {
+ static bool match(const VPRecipeBase *R) {
+ auto *DefR = dyn_cast<RecipeTy>(R);
+ return DefR && DefR->getOpcode() == Opcode;
+ }
+};
+
+template <unsigned Opcode, typename RecipeTy, typename... RecipeTys>
+struct MatchRecipeAndOpcode<Opcode, RecipeTy, RecipeTys...> {
+ static bool match(const VPRecipeBase *R) {
+ return MatchRecipeAndOpcode<Opcode, RecipeTy>::match(R) ||
+ MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R);
+ }
+};
+} // namespace detail
+
+template <typename Op0_t, unsigned Opcode, typename... RecipeTys>
+struct UnaryRecipe_match {
Op0_t Op0;
- UnaryVPInstruction_match(Op0_t Op0) : Op0(Op0) {}
+ UnaryRecipe_match(Op0_t Op0) : Op0(Op0) {}
bool match(const VPValue *V) {
auto *DefR = V->getDefiningRecipe();
@@ -64,37 +133,58 @@ template <typename Op0_t, unsigned Opcode> struct UnaryVPInstruction_match {
}
bool match(const VPRecipeBase *R) {
- auto *DefR = dyn_cast<VPInstruction>(R);
- if (!DefR || DefR->getOpcode() != Opcode)
+ if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
return false;
- assert(DefR->getNumOperands() == 1 &&
+ assert(R->getNumOperands() == 1 &&
"recipe with matched opcode does not have 1 operands");
- return Op0.match(DefR->getOperand(0));
+ return Op0.match(R->getOperand(0));
}
};
-template <typename Op0_t, typename Op1_t, unsigned Opcode>
-struct BinaryVPInstruction_match {
+template <typename Op0_t, unsigned Opcode>
+using UnaryVPInstruction_match =
+ UnaryRecipe_match<Op0_t, Opcode, VPInstruction>;
+
+template <typename Op0_t, unsigned Opcode>
+using AllUnaryRecipe_match =
+ UnaryRecipe_match<Op0_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
+ VPWidenCastRecipe, VPInstruction>;
+
+template <typename Op0_t, typename Op1_t, unsigned Opcode,
+ typename... RecipeTys>
+struct BinaryRecipe_match {
Op0_t Op0;
Op1_t Op1;
- BinaryVPInstruction_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
+ BinaryRecipe_match(Op0_t Op0, Op1_t Op1) : Op0(Op0), Op1(Op1) {}
bool match(const VPValue *V) {
auto *DefR = V->getDefiningRecipe();
return DefR && match(DefR);
}
+ bool match(const VPSingleDefRecipe *R) {
+ return match(static_cast<const VPRecipeBase *>(R));
+ }
+
bool match(const VPRecipeBase *R) {
- auto *DefR = dyn_cast<VPInstruction>(R);
- if (!DefR || DefR->getOpcode() != Opcode)
+ if (!detail::MatchRecipeAndOpcode<Opcode, RecipeTys...>::match(R))
return false;
- assert(DefR->getNumOperands() == 2 &&
+ assert(R->getNumOperands() == 2 &&
"recipe with matched opcode does not have 2 operands");
- return Op0.match(DefR->getOperand(0)) && Op1.match(DefR->getOperand(1));
+ return Op0.match(R->getOperand(0)) && Op1.match(R->getOperand(1));
}
};
+template <typename Op0_t, typename Op1_t, unsigned Opcode>
+using BinaryVPInstruction_match =
+ BinaryRecipe_match<Op0_t, Op1_t, Opcode, VPInstruction>;
+
+template <typename Op0_t, typename Op1_t, unsigned Opcode>
+using AllBinaryRecipe_match =
+ BinaryRecipe_match<Op0_t, Op1_t, Opcode, VPWidenRecipe, VPReplicateRecipe,
+ VPWidenCastRecipe, VPInstruction>;
+
template <unsigned Opcode, typename Op0_t>
inline UnaryVPInstruction_match<Op0_t, Opcode>
m_VPInstruction(const Op0_t &Op0) {
@@ -130,6 +220,47 @@ inline BinaryVPInstruction_match<Op0_t, Op1_t, VPInstruction::BranchOnCount>
m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) {
return m_VPInstruction<VPInstruction::BranchOnCount>(Op0, Op1);
}
+
+template <unsigned Opcode, typename Op0_t>
+inline AllUnaryRecipe_match<Op0_t, Opcode> m_Unary(const Op0_t &Op0) {
+ return AllUnaryRecipe_match<Op0_t, Opcode>(Op0);
+}
+
+template <typename Op0_t>
+inline AllUnaryRecipe_match<Op0_t, Instruction::Trunc>
+m_Trunc(const Op0_t &Op0) {
+ return m_Unary<Instruction::Trunc, Op0_t>(Op0);
+}
+
+template <typename Op0_t>
+inline AllUnaryRecipe_match<Op0_t, Instruction::ZExt> m_ZExt(const Op0_t &Op0) {
+ return m_Unary<Instruction::ZExt, Op0_t>(Op0);
+}
+
+template <typename Op0_t>
+inline AllUnaryRecipe_match<Op0_t, Instruction::SExt> m_SExt(const Op0_t &Op0) {
+ return m_Unary<Instruction::SExt, Op0_t>(Op0);
+}
+
+template <typename Op0_t>
+inline match_combine_or<AllUnaryRecipe_match<Op0_t, Instruction::ZExt>,
+ AllUnaryRecipe_match<Op0_t, Instruction::SExt>>
+m_ZExtOrSExt(const Op0_t &Op0) {
+ return m_CombineOr(m_ZExt(Op0), m_SExt(Op0));
+}
+
+template <unsigned Opcode, typename Op0_t, typename Op1_t>
+inline AllBinaryRecipe_match<Op0_t, Op1_t, Opcode> m_Binary(const Op0_t &Op0,
+ const Op1_t &Op1) {
+ return AllBinaryRecipe_match<Op0_t, Op1_t, Opcode>(Op0, Op1);
+}
+
+template <typename Op0_t, typename Op1_t>
+inline AllBinaryRecipe_match<Op0_t, Op1_t, Instruction::Mul>
+m_Mul(const Op0_t &Op0, const Op1_t &Op1) {
+ return m_Binary<Instruction::Mul, Op0_t, Op1_t>(Op0, Op1);
+}
+
} // namespace VPlanPatternMatch
} // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 0fc98fb69791d4..a91ccefe4b6d7d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -814,27 +814,6 @@ void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
}
}
-/// Returns true is \p V is constant one.
-static bool isConstantOne(VPValue *V) {
- if (!V->isLiveIn())
- return false;
- auto *C = dyn_cast<ConstantInt>(V->getLiveInIRValue());
- return C && C->isOne();
-}
-
-/// Returns the llvm::Instruction opcode for \p R.
-static unsigned getOpcodeForRecipe(VPRecipeBase &R) {
- if (auto *WidenR = dyn_cast<VPWidenRecipe>(&R))
- return WidenR->getUnderlyingInstr()->getOpcode();
- if (auto *WidenC = dyn_cast<VPWidenCastRecipe>(&R))
- return WidenC->getOpcode();
- if (auto *RepR = dyn_cast<VPReplicateRecipe>(&R))
- return RepR->getUnderlyingInstr()->getOpcode();
- if (auto *VPI = dyn_cast<VPInstruction>(&R))
- return VPI->getOpcode();
- return 0;
-}
-
/// Try to simplify recipe \p R.
static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
// Try to remove redundant blend recipes.
@@ -848,24 +827,9 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
return;
}
- switch (getOpcodeForRecipe(R)) {
- case Instruction::Mul: {
- VPValue *A = R.getOperand(0);
- VPValue *B = R.getOperand(1);
- if (isConstantOne(A))
- return R.getVPSingleValue()->replaceAllUsesWith(B);
- if (isConstantOne(B))
- return R.getVPSingleValue()->replaceAllUsesWith(A);
- break;
- }
- case Instruction::Trunc: {
- VPRecipeBase *Ext = R.getOperand(0)->getDefiningRecipe();
- if (!Ext)
- break;
- unsigned ExtOpcode = getOpcodeForRecipe(*Ext);
- if (ExtOpcode != Instruction::ZExt && ExtOpcode != Instruction::SExt)
- break;
- VPValue *A = Ext->getOperand(0);
+ using namespace llvm::VPlanPatternMatch;
+ VPValue *A;
+ if (match(&R, m_Trunc(m_ZExtOrSExt(m_VPValue(A))))) {
VPValue *Trunc = R.getVPSingleValue();
Type *TruncTy = TypeInfo.inferScalarType(Trunc);
Type *ATy = TypeInfo.inferScalarType(A);
@@ -874,8 +838,12 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
} else {
// Don't replace a scalarizing recipe with a widened cast.
if (isa<VPReplicateRecipe>(&R))
- break;
+ return;
if (ATy->getScalarSizeInBits() < TruncTy->getScalarSizeInBits()) {
+
+ unsigned ExtOpcode = match(R.getOperand(0), m_SExt(m_VPValue()))
+ ? Instruction::SExt
+ : Instruction::ZExt;
auto *VPC =
new VPWidenCastRecipe(Instruction::CastOps(ExtOpcode), A, TruncTy);
VPC->insertBefore(&R);
@@ -901,11 +869,11 @@ static void simplifyRecipe(VPRecipeBase &R, VPTypeAnalysis &TypeInfo) {
assert(TypeInfo.inferScalarType(VPV) == TypeInfo2.inferScalarType(VPV));
}
#endif
- break;
- }
- default:
- break;
}
+
+ if (match(&R, m_CombineOr(m_Mul(m_VPValue(A), m_SpecificInt(1)),
+ m_Mul(m_SpecificInt(1), m_VPValue(A)))))
+ return R.getVPSingleValue()->replaceAllUsesWith(A);
}
/// Try to simplify the recipes in \p Plan.
More information about the llvm-commits
mailing list