[llvm] SD Pattern Match: Operands patterns with VP Context (PR #103308)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 13 09:41:27 PDT 2024


https://github.com/v01dXYZ created https://github.com/llvm/llvm-project/pull/103308

Currently, when using a VP match context with `sd_context_match`, only opcode matching (`m_Opc(Opcode)`) is possible.

This PR suggests a way to make operand patterns (e.g. `m_Node`, `m_Add`, ...) work with a VP context.

This PR blocks https://github.com/llvm/llvm-project/pull/102877, which depends on it.

>From b52012694b0894ed4b9ef41287325d4378ab7539 Mon Sep 17 00:00:00 2001
From: v01dxyz <v01dxyz at v01d.xyz>
Date: Mon, 12 Aug 2024 13:04:52 +0200
Subject: [PATCH] [SDPatternMatch][VP] Operands pattern: Support VPMatchContext

Ignore the trailing VP-only operands (mask and EVL) of a VP node when matching operand patterns.
---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 11 +++--
 llvm/lib/CodeGen/SelectionDAG/MatchContext.h  |  4 ++
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 41 +++++++++++++++++++
 3 files changed, 52 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96ece1559bc437..b30efc9a25a39e 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -47,6 +47,20 @@ class BasicMatchContext {
   bool match(SDValue N, unsigned Opcode) const {
     return N->getOpcode() == Opcode;
   }
+
+  static constexpr bool IsVP = false;
 };
 
+/// Returns how many VP-only operands (mask and/or explicit vector length) \p N
+/// carries, or 0 if \p N is not a VP node. Not every VP node has both: e.g.
+/// VP_SELECT and VP_MERGE take an EVL but no mask. The operand patterns below
+/// assume these operands trail the ones shared with the base opcode.
+inline unsigned getNumVPOnlyOperands(SDValue N) {
+  unsigned Opc = N->getOpcode();
+  if (!ISD::isVPOpcode(Opc))
+    return 0;
+  return ISD::getVPMaskIdx(Opc).has_value() +
+         ISD::getVPExplicitVectorLengthIdx(Opc).has_value();
+}
+
 template <typename Pattern, typename MatchContext>
@@ -390,7 +404,9 @@ template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     // Returns false if there are more operands than predicates;
-    return N->getNumOperands() == OpIdx;
+    // under a VP context, a VP node's mask/EVL operands are not counted.
+    unsigned NumVPOnlyOps = Ctx.IsVP ? getNumVPOnlyOperands(N) : 0;
+    return N->getNumOperands() == OpIdx + NumVPOnlyOps;
   }
 };
 
@@ -464,7 +480,7 @@ struct TernaryOpc_match {
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
       EffectiveOperands<ExcludeChain> EO(N);
-      assert(EO.Size == 3);
+      assert(EO.Size == 3U + getNumVPOnlyOperands(N));
       return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
                Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
               (Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
@@ -516,7 +532,7 @@ struct BinaryOpc_match {
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
       EffectiveOperands<ExcludeChain> EO(N);
-      assert(EO.Size == 2);
+      assert(EO.Size == 2U + getNumVPOnlyOperands(N));
       return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
               RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
              (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
@@ -668,7 +684,7 @@ template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
       EffectiveOperands<ExcludeChain> EO(N);
-      assert(EO.Size == 1);
+      assert(EO.Size == 1U + getNumVPOnlyOperands(N));
       return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
     }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
index f965cb952f97a2..c1b3f7259aae33 100644
--- a/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
+++ b/llvm/lib/CodeGen/SelectionDAG/MatchContext.h
@@ -46,6 +46,10 @@ class EmptyMatchContext {
                                 bool LegalOnly = false) const {
     return TLI.isOperationLegalOrCustom(Op, VT, LegalOnly);
   }
+
+  /// Operand patterns from SDPatternMatch.h (e.g. m_Node) count every operand
+  /// of a node under this context, including a VP node's mask and EVL.
+  static constexpr bool IsVP = false;
 };
 
 class VPMatchContext {
@@ -170,6 +174,10 @@ class VPMatchContext {
     unsigned VPOp = ISD::getVPForBaseOpcode(Op);
     return TLI.isOperationLegalOrCustom(VPOp, VT, LegalOnly);
   }
+
+  /// Operand patterns from SDPatternMatch.h (e.g. m_Node) skip a VP node's
+  /// mask and EVL operands under this context.
+  static constexpr bool IsVP = true;
 };
 } // end anonymous namespace
 #endif
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 074247e6e7d184..66a2ec189dc199 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -383,6 +383,8 @@ struct VPMatchContext : public SDPatternMatch::BasicMatchContext {
     auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false);
     return BaseOpc.has_value() && *BaseOpc == Opc;
   }
+
+  static constexpr bool IsVP = true;
 };
 } // anonymous namespace
 TEST_F(SelectionDAGPatternMatchTest, matchContext) {
@@ -400,15 +402,54 @@ TEST_F(SelectionDAGPatternMatchTest, matchContext) {
                                {Vector0, Vector0, Mask0, Scalar0});
   SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT,
                                      {Scalar0, VPAdd, Mask0, Scalar0});
+  SDValue Add = DAG->getNode(ISD::ADD, DL, VInt32VT, {Vector0, Vector0});
 
   using namespace SDPatternMatch;
   VPMatchContext VPCtx(DAG.get());
   EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD)));
+  EXPECT_TRUE(
+      sd_context_match(VPAdd, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+  // VPMatchContext can't match a pattern that names an explicit VP opcode.
+  EXPECT_FALSE(sd_context_match(VPAdd, VPCtx,
+                                m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+  EXPECT_FALSE(sd_context_match(
+      VPAdd, VPCtx,
+      m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+  // Check the binary-op pattern (m_Add).
+  EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Add(m_Value(), m_Value())));
   // VP_REDUCE_ADD doesn't have a based opcode, so we use a normal
   // sd_match before switching to VPMatchContext when checking VPAdd.
   EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(),
                                            m_Context(VPCtx, m_Opc(ISD::ADD)),
                                            m_Value(), m_Value())));
+  // Non-VP nodes must still match under VPMatchContext.
+  EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Opc(ISD::ADD)));
+  EXPECT_TRUE(
+      sd_context_match(Add, VPCtx, m_Node(ISD::ADD, m_Value(), m_Value())));
+  EXPECT_FALSE(sd_context_match(
+      Add, VPCtx,
+      m_Node(ISD::ADD, m_Value(), m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_context_match(Add, VPCtx, m_Add(m_Value(), m_Value())));
+}
+
+TEST_F(SelectionDAGPatternMatchTest, matchVPWithBasicContext) {
+  SDLoc DL;
+  auto BoolVT = EVT::getIntegerVT(Context, 1);
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
+  auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4);
+
+  SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
+  SDValue Mask = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, MaskVT);
+  SDValue EL = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
+
+  SDValue VPAdd =
+      DAG->getNode(ISD::VP_ADD, DL, VInt32VT, Vector0, Vector0, Mask, EL);
+
+  using namespace SDPatternMatch;
+  EXPECT_FALSE(sd_match(VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      VPAdd, m_Node(ISD::VP_ADD, m_Value(), m_Value(), m_Value(), m_Value())));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {



More information about the llvm-commits mailing list