[llvm] SD Pattern Match: Operands patterns with VP Context (PR #103308)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 14 10:08:19 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: None (v01dXYZ)

<details>
<summary>Changes</summary>

Currently, when a VP match context is used with `sd_context_match`, only opcode matching (`m_Opc(Opcode)`) is possible.

This PR proposes a way to make patterns with operands (e.g. `m_Node`, `m_Add`, ...) work with a VP context.

This PR blocks another PR, https://github.com/llvm/llvm-project/pull/102877, which depends on it.

---
Full diff: https://github.com/llvm/llvm-project/pull/103308.diff


3 Files Affected:

- (modified) llvm/include/llvm/CodeGen/SDPatternMatch.h (+13-7) 
- (modified) llvm/lib/CodeGen/SelectionDAG/MatchContext.h (+6) 
- (modified) llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp (+43) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96ece1559bc437..c98e2bd9102ec8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -47,6 +47,8 @@ class BasicMatchContext {
   bool match(SDValue N, unsigned Opcode) const {
     return N->getOpcode() == Opcode;
   }
+
+  unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
 };
 
 template <typename Pattern, typename MatchContext>
@@ -390,7 +392,8 @@ template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     // Returns false if there are more operands than predicates;
-    return N->getNumOperands() == OpIdx;
+    // Ignores the last two operands if both the Context and the Node are VP
+    return Ctx.getNumOperands(N) == OpIdx;
   }
 };
 
@@ -424,8 +427,9 @@ template <bool ExcludeChain> struct EffectiveOperands {
   unsigned Size = 0;
   unsigned FirstIndex = 0;
 
-  explicit EffectiveOperands(SDValue N) {
-    const unsigned TotalNumOps = N->getNumOperands();
+  template <typename MatchContext>
+  explicit EffectiveOperands(SDValue N, const MatchContext &Ctx) {
+    const unsigned TotalNumOps = Ctx.getNumOperands(N);
     FirstIndex = TotalNumOps;
     for (unsigned I = 0; I < TotalNumOps; ++I) {
       // Count the number of non-chain and non-glue nodes (we ignore chain
@@ -444,7 +448,9 @@ template <> struct EffectiveOperands<false> {
   unsigned Size = 0;
   unsigned FirstIndex = 0;
 
-  explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
+  template <typename MatchContext>
+  explicit EffectiveOperands(SDValue N, const MatchContext &Ctx)
+      : Size(Ctx.getNumOperands(N)) {}
 };
 
 // === Ternary operations ===
@@ -463,7 +469,7 @@ struct TernaryOpc_match {
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
-      EffectiveOperands<ExcludeChain> EO(N);
+      EffectiveOperands<ExcludeChain> EO(N, Ctx);
       assert(EO.Size == 3);
       return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
                Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -515,7 +521,7 @@ struct BinaryOpc_match {
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
-      EffectiveOperands<ExcludeChain> EO(N);
+      EffectiveOperands<ExcludeChain> EO(N, Ctx);
       assert(EO.Size == 2);
       return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
               RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
@@ -667,7 +673,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
-      EffectiveOperands<ExcludeChain> EO(N);
+      EffectiveOperands<ExcludeChain> EO(N, Ctx);
       assert(EO.Size == 1);
       return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
     }
diff --git a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
index f965cb952f97a2..ef748688cea7ba 100644
--- a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
+++ b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
@@ -46,6 +46,8 @@ class EmptyMatchContext {
                                 bool LegalOnly = false) const {
     return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
   }
+
+  unsigned getNumOperands(SDValue N) const { return N->getNumOperands(); }
 };
 
 class VPMatchContext {
@@ -170,6 +172,10 @@ class VPMatchContext {
     unsigned VPOp = ISD::getVPForBaseOpcode(Op);
     return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
   }
+
+  unsigned getNumOperands(SDValue N) const {
+    return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
+  }
 };
 } // end anonymous namespace
 #endif
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 074247e6e7d184..3c47fb085b14c0 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -383,6 +383,10 @@ struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
     auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
     return BaseOpc.has_value() && *BaseOpc == Opc;
   }
+
+  unsigned getNumOperands(SDValue N) const {
+    return N->isVPOpcode() ? N->getNumOperands() - 2 : N->getNumOperands();
+  }
 };
 } // anonymous namespace
 TEST_F(SelectionDAGPatternMatchTest, matchContext) {
@@ -400,15 +404,54 @@ TEST_F(SelectionDAGPatternMatchTest, matchContext) {
                                {Vector0, Vector0, Mask0, Scalar0});
   SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
                                      {Scalar0, VPAdd, Mask0, Scalar0});
+  SDValue Add = DAG->getNode(ISD::ADD, DL, VInt32VT, {Vector0, Vector0});
 
   using namespace SDPatternMatch;
   VPMatchContext VPCtx(DAG.get());
   EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
+  EXPECT_TRUE(
+      sd_context_match(VPAdd, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+  // VPMatchContext can't match pattern using explicit VP Opcode
+  EXPECT_FALSE(sd_context_match(VPAdd, VPCtx,
+                                m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+  EXPECT_FALSE(sd_context_match(
+      VPAdd, VPCtx,
+      m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+  // Check Binary Op Pattern
+  EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Add(m_Value(), m_Value())));
   // VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
   // sd_match before switching to VPMatchContext when checking VPAdd.
   EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
                                            m_Context(VPCtx, m_Opc(ISD::ADD)),
                                            m_Value(), m_Value())));
+  // non-vector predicated should match too
+  EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Opc(ISD::ADD)));
+  EXPECT_TRUE(
+      sd_context_match(Add, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+  EXPECT_FALSE(sd_context_match(
+      Add, VPCtx,
+      m_Node(ISD::ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Add(m_Value(), m_Value())));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchVPWithBasicContext) {
+  SDLoc DL;
+  auto BoolVT = EVT::getIntegerVT(Context, 1);
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+  auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
+
+  SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+  SDValue Mask = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, MaskVT);
+  SDValue EL = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
+
+  SDValue VPAdd =
+      DAG->getNode(ISD::VP_ADD, DL, VInt32VT, Vector0, Vector0, Mask, EL);
+
+  using namespace SDPatternMatch;
+  EXPECT_FALSE(sd_match(VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/103308


More information about the llvm-commits mailing list