[llvm] SD Pattern Match: Operands patterns with VP Context (PR #103308)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 13 09:41:27 PDT 2024
https://github.com/v01dXYZ created https://github.com/llvm/llvm-project/pull/103308
Currently, when using a VP match context with `sd_context_match`, only Opcode matching is possible (`m_Opc(Opcode)`).
This PR suggests a way to make patterns with operands (e.g. `m_Node`, `m_Add`, ...) work with a VP context.
This PR blocks another PR, https://github.com/llvm/llvm-project/pull/102877, which depends on it.
>From b52012694b0894ed4b9ef41287325d4378ab7539 Mon Sep 17 00:00:00 2001
From: v01dxyz <v01dxyz at v01d.xyz>
Date: Mon, 12 Aug 2024 13:04:52 +0200
Subject: [PATCH] [SDPatternMatch][VP] Operands pattern: Support VPMatchContext
Ignore the last two operands (the mask and the explicit vector length) when both the match context and the node are VP.
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 11 +++--
llvm/lib/CodeGen/SelectionDAG/MatchContext.h | 4 ++
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 41 +++++++++++++++++++
3 files changed, 52 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96ece1559bc437..b30efc9a25a39e 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -47,6 +47,8 @@ class BasicMatchContext {
bool match(SDValue N, unsigned Opcode) const {
return N->getOpcode() == Opcode;
}
+
+ static constexpr bool IsVP = false;
};
template <typename Pattern, typename MatchContext>
@@ -390,7 +392,8 @@ template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
// Returns false if there are more operands than predicates;
- return N->getNumOperands() == OpIdx;
+ // Ignores the last two operands if both the Context and the Node are VP
+ return N->getNumOperands() == (OpIdx + 2 * Ctx.IsVP * N->isVPOpcode());
}
};
@@ -464,7 +467,7 @@ struct TernaryOpc_match {
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
- assert(EO.Size == 3);
+ assert(EO.Size == 3U + 2 * N->isVPOpcode());
return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
(Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
@@ -516,7 +519,7 @@ struct BinaryOpc_match {
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
- assert(EO.Size == 2);
+ assert(EO.Size == 2U + 2 * N->isVPOpcode());
return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
(Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
@@ -668,7 +671,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
- assert(EO.Size == 1);
+ assert(EO.Size == 1U + 2 * N->isVPOpcode());
return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
index f965cb952f97a2..c1b3f7259aae33 100644
--- a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
+++ b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
@@ -46,6 +46,8 @@ class EmptyMatchContext {
bool LegalOnly = false) const {
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
}
+
+ static constexpr bool IsVP = false;
};
class VPMatchContext {
@@ -170,6 +172,8 @@ class VPMatchContext {
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
}
+
+ static constexpr bool IsVP = true;
};
} // end anonymous namespace
#endif
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 074247e6e7d184..66a2ec189dc199 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -383,6 +383,8 @@ struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
return BaseOpc.has_value() && *BaseOpc == Opc;
}
+
+ static constexpr bool IsVP = true;
};
} // anonymous namespace
TEST_F(SelectionDAGPatternMatchTest, matchContext) {
@@ -400,15 +402,54 @@ TEST_F(SelectionDAGPatternMatchTest, matchContext) {
{Vector0, Vector0, Mask0, Scalar0});
SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
{Scalar0, VPAdd, Mask0, Scalar0});
+ SDValue Add = DAG->getNode(ISD::ADD, DL, VInt32VT, {Vector0, Vector0});
using namespace SDPatternMatch;
VPMatchContext VPCtx(DAG.get());
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
+ EXPECT_TRUE(
+ sd_context_match(VPAdd, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+ // VPMatchContext can't match pattern using explicit VP Opcode
+ EXPECT_FALSE(sd_context_match(VPAdd, VPCtx,
+ m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+ EXPECT_FALSE(sd_context_match(
+ VPAdd, VPCtx,
+ m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+ // Check Binary Op Pattern
+ EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Add(m_Value(), m_Value())));
// VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
// sd_match before switching to VPMatchContext when checking VPAdd.
EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
m_Context(VPCtx, m_Opc(ISD::ADD)),
m_Value(), m_Value())));
+ // non-vector predicated should match too
+ EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Opc(ISD::ADD)));
+ EXPECT_TRUE(
+ sd_context_match(Add, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+ EXPECT_FALSE(sd_context_match(
+ Add, VPCtx,
+ m_Node(ISD::ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+ EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Add(m_Value(), m_Value())));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchVPWithBasicContext) {
+ SDLoc DL;
+ auto BoolVT = EVT::getIntegerVT(Context, 1);
+ auto Int32VT = EVT::getIntegerVT(Context, 32);
+ auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+ auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
+
+ SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+ SDValue Mask = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, MaskVT);
+ SDValue EL = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
+
+ SDValue VPAdd =
+ DAG->getNode(ISD::VP_ADD, DL, VInt32VT, Vector0, Vector0, Mask, EL);
+
+ using namespace SDPatternMatch;
+ EXPECT_FALSE(sd_match(VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(
+ VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
}
TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
More information about the llvm-commits mailing list