[llvm] [DAG][PatternMatch] Add support for matchers with flags; NFC (PR #103060)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 13 06:20:53 PDT 2024


https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/103060

Add support for matching on `SDNodeFlags`, e.g. `add` with `nuw`.

This patch adds helpers for `or disjoint` and `zext nneg`, plus
`m_AddLike` and `m_SExtLike`, using the same names as the IR
PatternMatch API.


>From 29b25280c964ca1f22d790466191e8951a3f01ed Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 13 Aug 2024 21:18:45 +0800
Subject: [PATCH] [DAG][PatternMatch] Add support for matchers with flags; NFC

Add support for matching on `SDNodeFlags`, e.g. `add` with `nuw`.

This patch adds helpers for `or disjoint` and `zext nneg`, plus
`m_AddLike` and `m_SExtLike`, using the same names as the IR
PatternMatch API.
---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 62 ++++++++++++++++---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 14 +++++
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 23 +++++++
 3 files changed, 89 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96ece1559bc437..adeaf2fabd39e0 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -508,19 +508,28 @@ struct BinaryOpc_match {
   unsigned Opcode;
   LHS_P LHS;
   RHS_P RHS;
-
-  BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R)
-      : Opcode(Opc), LHS(L), RHS(R) {}
+  std::optional<SDNodeFlags> Flags; // If set, flags the node must also have.
+  BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
+                  std::optional<SDNodeFlags> Flgs = std::nullopt)
+      : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
       EffectiveOperands<ExcludeChain> EO(N);
       assert(EO.Size == 2);
-      return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
-              RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
-             (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
-              RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
+      // Test the required flags (if any) before the operands: it is cheap and
+      // avoids running the operand sub-matchers on a node that cannot match.
+      if (Flags.has_value()) {
+        SDNodeFlags TmpFlags = *Flags;
+        TmpFlags.intersectWith(N->getFlags());
+        if (!(TmpFlags == *Flags))
+          return false;
+      }
+      return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+              RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
+             (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+              RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
     }
 
     return false;
@@ -575,6 +584,19 @@ inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
   return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
 }
 
+template <typename LHS, typename RHS> // Match an OR with the disjoint flag.
+inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
+                                                    const RHS &R) {
+  SDNodeFlags Flags{}; // Only 'disjoint' is required.
+  Flags.setDisjoint(true);
+  return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, Flags);
+}
+/// Match an ADD, or an OR carrying the disjoint flag (see m_DisjointOr).
+template <typename LHS, typename RHS>
+inline auto m_AddLike(const LHS &L, const RHS &R) {
+  return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
+}
+
 template <typename LHS, typename RHS>
 inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
   return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
@@ -661,15 +683,24 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
 template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
   unsigned Opcode;
   Opnd_P Opnd;
-
-  UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {}
+  std::optional<SDNodeFlags> Flags; // If set, flags the node must also have.
+  UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
+                 std::optional<SDNodeFlags> Flgs = std::nullopt)
+      : Opcode(Opc), Opnd(Op), Flags(Flgs) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
       EffectiveOperands<ExcludeChain> EO(N);
       assert(EO.Size == 1);
-      return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
+      // Reject nodes missing a required flag before matching the operand.
+      if (Flags.has_value()) {
+        SDNodeFlags TmpFlags = *Flags;
+        TmpFlags.intersectWith(N->getFlags());
+        if (!(TmpFlags == *Flags))
+          return false;
+      }
+      return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
     }
 
     return false;
@@ -695,6 +726,13 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
   return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
 }
 
+template <typename Opnd> // Match a ZERO_EXTEND with the nneg flag.
+inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
+  SDNodeFlags Flags{}; // Only 'nneg' is required.
+  Flags.setNonNeg(true);
+  return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, Flags);
+}
+
 template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
   return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
 }
@@ -719,6 +757,10 @@ template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
   return m_AnyOf(m_SExt(Op), Op);
 }
 
+template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
+  return m_AnyOf(m_SExt(Op), m_NNegZExt(Op)); // zext nneg is sext-like.
+}
+
 /// Match a aext or identity
 /// Allows to peek through optional extensions
 template <typename Opnd>
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 2f36c2e86b1c3a..7837a5f12214bb 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -452,6 +452,20 @@ struct SDNodeFlags {
   bool hasNoFPExcept() const { return NoFPExcept; }
   bool hasUnpredictable() const { return Unpredictable; }
 
+  /// Equal iff every flag matches; extend this when a new flag is added.
+  bool operator==(const SDNodeFlags &Other) const {
+    return NoUnsignedWrap == Other.NoUnsignedWrap &&
+           NoSignedWrap == Other.NoSignedWrap && Exact == Other.Exact &&
+           Disjoint == Other.Disjoint && NonNeg == Other.NonNeg &&
+           NoNaNs == Other.NoNaNs && NoInfs == Other.NoInfs &&
+           NoSignedZeros == Other.NoSignedZeros &&
+           AllowReciprocal == Other.AllowReciprocal &&
+           AllowContract == Other.AllowContract &&
+           ApproximateFuncs == Other.ApproximateFuncs &&
+           AllowReassociation == Other.AllowReassociation &&
+           NoFPExcept == Other.NoFPExcept &&
+           Unpredictable == Other.Unpredictable;
+  }
   /// Clear any flags in this flag set that aren't also set in Flags. All
   /// flags will be cleared if Flags are undefined.
   void intersectWith(const SDNodeFlags Flags) {
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 074247e6e7d184..6db31990968afa 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -185,6 +185,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
   SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
+  SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
 
   SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
   SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
@@ -192,6 +193,9 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
   SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0);
   SDValue Or  = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
+  SDNodeFlags DisFlags{}; // Op3, not Op1: presumably avoids CSE with Or.
+  DisFlags.setDisjoint(true);
+  SDValue DisOr = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op3, DisFlags);
   SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1);
   SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0);
   SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1);
@@ -205,6 +209,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(Add, m_AddLike(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(
       Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
   EXPECT_TRUE(
@@ -217,6 +222,12 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(Or, m_DisjointOr(m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(Or, m_AddLike(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(DisOr, m_Or(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(DisOr, m_DisjointOr(m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(DisOr, m_Add(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(DisOr, m_AddLike(m_Value(), m_Value())));
 
   EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value())));
@@ -241,8 +252,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
 
   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
 
   SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
+  SDNodeFlags NNegFlags{}; // Op2, not Op0: presumably avoids CSE with ZExt.
+  NNegFlags.setNonNeg(true);
+  SDValue ZExtNNeg =
+      DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op2, NNegFlags);
   SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0);
   SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1);
 
@@ -255,6 +271,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
   using namespace SDPatternMatch;
   EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
   EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
+  EXPECT_TRUE(sd_match(SExt, m_SExtLike(m_Value())));
+  ASSERT_TRUE(ZExtNNeg->getFlags().hasNonNeg()); // Flag must survive getNode.
+  EXPECT_FALSE(sd_match(ZExtNNeg, m_SExt(m_Value())));
+  EXPECT_TRUE(sd_match(ZExtNNeg, m_NNegZExt(m_Value())));
+  EXPECT_FALSE(sd_match(ZExt, m_NNegZExt(m_Value())));
+  EXPECT_TRUE(sd_match(ZExtNNeg, m_SExtLike(m_Value())));
+  EXPECT_FALSE(sd_match(ZExt, m_SExtLike(m_Value())));
   EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));
 
   EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));



More information about the llvm-commits mailing list