[llvm] SD Pattern Match: Operands patterns with VP Context (PR #103308)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 14 10:08:19 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: None (v01dXYZ)
<details>
<summary>Changes</summary>
Currently, when using a VP match context with `sd_context_match`, only Opcode matching is possible (`m_Opc(Opcode)`).
This PR suggests a way to make patterns with Operands (e.g. `m_Node`, `m_Add`, ...) work with a VP context.
This PR blocks another PR https://github.com/llvm/llvm-project/pull/102877.
---
Full diff: https://github.com/llvm/llvm-project/pull/103308.diff
3 Files Affected:
- (modified) llvm/include/llvm/CodeGen/SDPatternMatch.h (+13-7)
- (modified) llvm/lib/CodeGen/SelectionDAG/MatchContext.h (+6)
- (modified) llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp (+43)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96ece1559bc437..c98e2bd9102ec8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -47,6 +47,8 @@ class BasicMatchContext {
bool match(SDValue N, unsigned Opcode) const {
return N->getOpcode() == Opcode;
}
+
+ unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
};
template <typename Pattern, typename MatchContext>
@@ -390,7 +392,8 @@ template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
// Returns false if there are more operands than predicates;
- return N->getNumOperands() == OpIdx;
+ // Ignores the last two operands if both the Context and the Node are VP
+ return Ctx.getNumOperands(N) == OpIdx;
}
};
@@ -424,8 +427,9 @@ template <bool ExcludeChain> struct EffectiveOperands {
unsigned Size = 0;
unsigned FirstIndex = 0;
- explicit EffectiveOperands(SDValue N) {
- const unsigned TotalNumOps = N->getNumOperands();
+ template <typename MatchContext>
+ explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
+ const unsigned TotalNumOps = Ctx.getNumOperands(N);
FirstIndex = TotalNumOps;
for (unsigned I = 0; I < TotalNumOps; ++I) {
// Count the number of non-chain and non-glue nodes (we ignore chain
@@ -444,7 +448,9 @@ template <> struct EffectiveOperands<false> {
unsigned Size = 0;
unsigned FirstIndex = 0;
- explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
+ template <typename MatchContext>
+ explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
+ : Size(Ctx.getNumOperands(N)) {}
};
// === Ternary operations ===
@@ -463,7 +469,7 @@ struct TernaryOpc_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
- EffectiveOperands<ExcludeChain> EO(N);
+ EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 3);
return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -515,7 +521,7 @@ struct BinaryOpc_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
- EffectiveOperands<ExcludeChain> EO(N);
+ EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 2);
return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -667,7 +673,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
- EffectiveOperands<ExcludeChain> EO(N);
+ EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 1);
return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
index f965cb952f97a2..ef748688cea7ba 100644
--- a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
+++ b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
@@ -46,6 +46,8 @@ class EmptyMatchContext {
bool LegalOnly = false) const {
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
}
+
+ unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
};
class VPMatchContext {
@@ -170,6 +172,10 @@ class VPMatchContext {
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
}
+
+ unsigned getNumOperands(SDValue N) const {
+ return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
+ }
};
} // end anonymous namespace
#endif
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 074247e6e7d184..3c47fb085b14c0 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -383,6 +383,10 @@ struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
return BaseOpc.has_value() && *BaseOpc == Opc;
}
+
+ unsigned getNumOperands(SDValue N) const {
+ return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
+ }
};
} // anonymous namespace
TEST_F(SelectionDAGPatternMatchTest, matchContext) {
@@ -400,15 +404,54 @@ TEST_F(SelectionDAGPatternMatchTest, matchContext) {
{Vector0, Vector0, Mask0, Scalar0});
SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
{Scalar0, VPAdd, Mask0, Scalar0});
+ SDValue Add = DAG->getNode(ISD::ADD, DL, VInt32VT, {Vector0, Vector0});
using namespace SDPatternMatch;
VPMatchContext VPCtx(DAG.get());
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
+ EXPECT_TRUE(
+ sd_context_match(VPAdd, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+ // VPMatchContext can't match pattern using explicit VP Opcode
+ EXPECT_FALSE(sd_context_match(VPAdd, VPCtx,
+ m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+ EXPECT_FALSE(sd_context_match(
+ VPAdd, VPCtx,
+ m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+ // Check Binary Op Pattern
+ EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Add(m_Value(), m_Value())));
// VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
// sd_match before switching to VPMatchContext when checking VPAdd.
EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
m_Context(VPCtx, m_Opc(ISD::ADD)),
m_Value(), m_Value())));
+ // non-vector predicated should match too
+ EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Opc(ISD::ADD)));
+ EXPECT_TRUE(
+ sd_context_match(Add, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+ EXPECT_FALSE(sd_context_match(
+ Add, VPCtx,
+ m_Node(ISD::ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+ EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Add(m_Value(), m_Value())));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchVPWithBasicContext) {
+ SDLoc DL;
+ auto BoolVT = EVT::getIntegerVT(Context, 1);
+ auto Int32VT = EVT::getIntegerVT(Context, 32);
+ auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+ auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
+
+ SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+ SDValue Mask = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, MaskVT);
+ SDValue EL = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
+
+ SDValue VPAdd =
+ DAG->getNode(ISD::VP_ADD, DL, VInt32VT, Vector0, Vector0, Mask, EL);
+
+ using namespace SDPatternMatch;
+ EXPECT_FALSE(sd_match(VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(
+ VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
}
TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/103308
More information about the llvm-commits
mailing list