[llvm] fc1b019 - [DAG] SD Pattern Match: Operands patterns with VP Context (#103308)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 16 01:46:23 PDT 2024
Author: v01dXYZ
Date: 2024-08-16T09:46:20+01:00
New Revision: fc1b01963857b5c04980c713145a71d6b858ad8a
URL: https://github.com/llvm/llvm-project/commit/fc1b01963857b5c04980c713145a71d6b858ad8a
DIFF: https://github.com/llvm/llvm-project/commit/fc1b01963857b5c04980c713145a71d6b858ad8a.diff
LOG: [DAG] SD Pattern Match: Operands patterns with VP Context (#103308)
Currently, when using a VP match context with `sd_context_match`, only Opcode matching is possible (`m_Opc(Opcode)`).
This PR suggests a way to make patterns with operands (e.g. `m_Node`, `m_Add`, ...) work with a VP context.
This PR blocks (is a prerequisite for) another PR: https://github.com/llvm/llvm-project/pull/102877.
Co-authored-by: v01dxyz <v01dxyz at v01d.xyz>
Added:
Modified:
llvm/include/llvm/CodeGen/SDPatternMatch.h
llvm/lib/CodeGen/SelectionDAG/MatchContext.h
llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 88ddd43a2a8913..b1aa87ca2d3e13 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -47,6 +47,8 @@ class BasicMatchContext {
bool match(SDValue N, unsigned Opcode) const {
return N->getOpcode() == Opcode;
}
+
+ unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
};
template <typename Pattern, typename MatchContext>
@@ -390,7 +392,8 @@ template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
// Returns false if there are more operands than predicates;
- return N->getNumOperands() == OpIdx;
+ // Ignores the last two operands if both the Context and the Node are VP
+ return Ctx.getNumOperands(N) == OpIdx;
}
};
@@ -424,8 +427,9 @@ template <bool ExcludeChain> struct EffectiveOperands {
unsigned Size = 0;
unsigned FirstIndex = 0;
- explicit EffectiveOperands(SDValue N) {
- const unsigned TotalNumOps = N->getNumOperands();
+ template <typename MatchContext>
+ explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
+ const unsigned TotalNumOps = Ctx.getNumOperands(N);
FirstIndex = TotalNumOps;
for (unsigned I = 0; I < TotalNumOps; ++I) {
// Count the number of non-chain and non-glue nodes (we ignore chain
@@ -444,7 +448,9 @@ template <> struct EffectiveOperands<false> {
unsigned Size = 0;
unsigned FirstIndex = 0;
- explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
+ template <typename MatchContext>
+ explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
+ : Size(Ctx.getNumOperands(N)) {}
};
// === Ternary operations ===
@@ -463,7 +469,7 @@ struct TernaryOpc_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
- EffectiveOperands<ExcludeChain> EO(N);
+ EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 3);
return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -515,7 +521,7 @@ struct BinaryOpc_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
- EffectiveOperands<ExcludeChain> EO(N);
+ EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 2);
return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -667,7 +673,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
- EffectiveOperands<ExcludeChain> EO(N);
+ EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 1);
return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
index d81a0c2e87e7d2..8f03532af99e86 100644
--- a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
+++ b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
@@ -45,6 +45,8 @@ class EmptyMatchContext {
bool LegalOnly = false) const {
return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
}
+
+ unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
};
class VPMatchContext {
@@ -169,6 +171,10 @@ class VPMatchContext {
unsigned VPOp = ISD::getVPForBaseOpcode(Op);
return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
}
+
+ unsigned getNumOperands(SDValue N) const {
+ return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
+ }
};
} // namespace llvm
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index b9414be98623af..c04fc5621ab499 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -393,6 +393,10 @@ struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
return BaseOpc.has_value() && *BaseOpc == Opc;
}
+
+ unsigned getNumOperands(SDValue N) const {
+ return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
+ }
};
} // anonymous namespace
TEST_F(SelectionDAGPatternMatchTest, matchContext) {
@@ -410,15 +414,54 @@ TEST_F(SelectionDAGPatternMatchTest, matchContext) {
{Vector0, Vector0, Mask0, Scalar0});
SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
{Scalar0, VPAdd, Mask0, Scalar0});
+ SDValue Add = DAG->getNode(ISD::ADD, DL, VInt32VT, {Vector0, Vector0});
using namespace SDPatternMatch;
VPMatchContext VPCtx(DAG.get());
EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
+ EXPECT_TRUE(
+ sd_context_match(VPAdd, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+ // VPMatchContext can't match pattern using explicit VP Opcode
+ EXPECT_FALSE(sd_context_match(VPAdd, VPCtx,
+ m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+ EXPECT_FALSE(sd_context_match(
+ VPAdd, VPCtx,
+ m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+ // Check Binary Op Pattern
+ EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Add(m_Value(), m_Value())));
// VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
// sd_match before switching to VPMatchContext when checking VPAdd.
EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
m_Context(VPCtx, m_Opc(ISD::ADD)),
m_Value(), m_Value())));
+ // non-vector predicated should match too
+ EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Opc(ISD::ADD)));
+ EXPECT_TRUE(
+ sd_context_match(Add, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+ EXPECT_FALSE(sd_context_match(
+ Add, VPCtx,
+ m_Node(ISD::ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+ EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Add(m_Value(), m_Value())));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchVPWithBasicContext) {
+ SDLoc DL;
+ auto BoolVT = EVT::getIntegerVT(Context, 1);
+ auto Int32VT = EVT::getIntegerVT(Context, 32);
+ auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+ auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
+
+ SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+ SDValue Mask = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, MaskVT);
+ SDValue EL = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
+
+ SDValue VPAdd =
+ DAG->getNode(ISD::VP_ADD, DL, VInt32VT, Vector0, Vector0, Mask, EL);
+
+ using namespace SDPatternMatch;
+ EXPECT_FALSE(sd_match(VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(
+ VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
}
TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
More information about the llvm-commits
mailing list