[llvm] [SDPatternMatch] Add m_CondCode, m_NoneOf, and some SExt improvements (PR #90762)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 15:33:11 PDT 2024


https://github.com/mshockwave updated https://github.com/llvm/llvm-project/pull/90762

>From b4e00e91b5f3178328f2bfae2ca71fbf84b0701a Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 1 May 2024 11:53:25 -0700
Subject: [PATCH 1/4] [SDPatternMatch] Add m_CondCode, m_NoneOf, and some
 ZExt/SExt improvements

  - Add m_CondCode to match the ISD::CondCode value from CondCodeSDNode
  - Add m_NoneOf combinator
  - m_ZExt now recognizes (and X, ~0); m_SExt now recognizes sext_inreg
---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 88 +++++++++++++++----
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 17 ++++
 2 files changed, 87 insertions(+), 18 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 4cc7bb9c3b55a9..822086c3a5ea12 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -358,6 +358,19 @@ struct Or<Pred, Preds...> : Or<Preds...> {
   }
 };
 
+template <typename Pred> struct Not {
+  Pred P;
+
+  explicit Not(const Pred &P) : P(P) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    return !P.match(Ctx, N);
+  }
+};
+// Explicit deduction guide.
+template <typename Pred> Not(const Pred &P) -> Not<Pred>;
+
 template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) {
   return And<Preds...>(std::forward<Preds>(preds)...);
 }
@@ -366,6 +379,10 @@ template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) {
   return Or<Preds...>(std::forward<Preds>(preds)...);
 }
 
+template <typename... Preds> auto m_NoneOf(Preds &&...preds) {
+  return Not{m_AnyOf(std::forward<Preds>(preds)...)};
+}
+
 // === Generic node matching ===
 template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
   template <typename MatchContext>
@@ -616,12 +633,10 @@ inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc,
   return UnaryOpc_match<Opnd, true>(Opc, Op);
 }
 
-template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
-  return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
-}
-
-template <typename Opnd> inline UnaryOpc_match<Opnd> m_SExt(const Opnd &Op) {
-  return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
+template <typename Opnd> inline auto m_SExt(Opnd &&Op) {
+  return m_AnyOf(
+      UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op),
+      m_Node(ISD::SIGN_EXTEND_INREG, std::forward<Opnd>(Op), m_Value()));
 }
 
 template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
@@ -632,20 +647,10 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
   return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op);
 }
 
-/// Match a zext or identity
-/// Allows to peek through optional extensions
-template <typename Opnd>
-inline Or<UnaryOpc_match<Opnd>, Opnd> m_ZExtOrSelf(Opnd &&Op) {
-  return Or<UnaryOpc_match<Opnd>, Opnd>(m_ZExt(std::forward<Opnd>(Op)),
-                                        std::forward<Opnd>(Op));
-}
-
 /// Match a sext or identity
 /// Allows to peek through optional extensions
-template <typename Opnd>
-inline Or<UnaryOpc_match<Opnd>, Opnd> m_SExtOrSelf(Opnd &&Op) {
-  return Or<UnaryOpc_match<Opnd>, Opnd>(m_SExt(std::forward<Opnd>(Op)),
-                                        std::forward<Opnd>(Op));
+template <typename Opnd> inline auto m_SExtOrSelf(Opnd &&Op) {
+  return m_AnyOf(m_SExt(std::forward<Opnd>(Op)), std::forward<Opnd>(Op));
 }
 
 /// Match a aext or identity
@@ -718,6 +723,20 @@ inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
 inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
 inline SpecificInt_match m_AllOnes() { return m_SpecificInt(~0U); }
 
+// FIXME: We probably ought to move constant matchers before
+// most of the others so that m_ZExt/m_ZExtOrSelf can be
+// with other extension matchers.
+template <typename Opnd> inline auto m_ZExt(const Opnd &Op) {
+  return m_AnyOf(UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op),
+                 m_And(Op, m_AllOnes()));
+}
+
+/// Match a zext or identity
+/// Allows to peek through optional extensions
+template <typename Opnd> inline auto m_ZExtOrSelf(Opnd &&Op) {
+  return m_AnyOf(m_ZExt(std::forward<Opnd>(Op)), std::forward<Opnd>(Op));
+}
+
 /// Match true boolean value based on the information provided by
 /// TargetLowering.
 inline auto m_True() {
@@ -758,6 +777,39 @@ inline auto m_False() {
       m_Value()};
 }
 
+struct CondCode_match {
+  std::optional<ISD::CondCode> CCToMatch;
+  ISD::CondCode *BindCC;
+
+  explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC), BindCC(nullptr) {}
+
+  explicit CondCode_match(ISD::CondCode *CC)
+      : CCToMatch(std::nullopt), BindCC(CC) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    if (auto *CC = dyn_cast<CondCodeSDNode>(N.getNode())) {
+      if (CCToMatch && *CCToMatch != CC->get())
+        return false;
+
+      if (BindCC)
+        *BindCC = CC->get();
+      return true;
+    }
+
+    return false;
+  }
+};
+
+/// Match any condition code node (CondCodeSDNode).
+inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
+/// Match any condition code node (CondCodeSDNode) and bind its value to CC.
+inline CondCode_match m_CondCode(ISD::CondCode &CC) {
+  return CondCode_match(&CC);
+}
+inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
+  return CondCode_match(CC);
+}
+
 /// Match a negate as a sub(0, v)
 template <typename ValTy>
 inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index a7112cfac63de5..bbe11aa7ecaaa7 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -217,6 +217,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
   SDValue Zero = DAG->getConstant(0, DL, Int32VT);
   SDValue One = DAG->getConstant(1, DL, Int32VT);
   SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, Int32VT);
+  SDValue SetCC = DAG->getSetCC(DL, Int32VT, Arg0, Const3, ISD::SETULT);
 
   using namespace SDPatternMatch;
   EXPECT_TRUE(sd_match(Const87, m_ConstInt()));
@@ -233,6 +234,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
   EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False()));
   EXPECT_TRUE(sd_match(One, DAG.get(), m_True()));
   EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True()));
+
+  ISD::CondCode CC;
+  EXPECT_TRUE(sd_match(
+      SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(), m_CondCode(CC))));
+  EXPECT_EQ(CC, ISD::SETULT);
+  EXPECT_TRUE(sd_match(SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(),
+                                     m_SpecificCondCode(ISD::SETULT))));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
@@ -249,6 +257,7 @@ TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
   EXPECT_TRUE(sd_match(
       Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
   EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse())));
+  EXPECT_TRUE(sd_match(Add, m_NoneOf(m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
@@ -259,7 +268,11 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
   SDValue Op32 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
   SDValue Op64 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
   SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op32);
+  SDValue AndZExt = DAG->getNode(ISD::AND, DL, Int64VT, Op64,
+                                 DAG->getConstant(~0U, DL, Int64VT));
   SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op32);
+  SDValue SExtInReg = DAG->getNode(ISD::SIGN_EXTEND_INREG, DL, Int64VT, Op64,
+                                   DAG->getValueType(Int32VT));
   SDValue AExt = DAG->getNode(ISD::ANY_EXTEND, DL, Int64VT, Op32);
   SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op64);
 
@@ -269,10 +282,14 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
   EXPECT_TRUE(A == Op32);
   EXPECT_TRUE(sd_match(ZExt, m_ZExtOrSelf(m_Value(A))));
   EXPECT_TRUE(A == Op32);
+  EXPECT_TRUE(sd_match(AndZExt, m_ZExtOrSelf(m_Value(A))));
+  EXPECT_TRUE(A == Op64);
   EXPECT_TRUE(sd_match(Op64, m_SExtOrSelf(m_Value(A))));
   EXPECT_TRUE(A == Op64);
   EXPECT_TRUE(sd_match(SExt, m_SExtOrSelf(m_Value(A))));
   EXPECT_TRUE(A == Op32);
+  EXPECT_TRUE(sd_match(SExtInReg, m_SExtOrSelf(m_Value(A))));
+  EXPECT_TRUE(A == Op64);
   EXPECT_TRUE(sd_match(Op32, m_AExtOrSelf(m_Value(A))));
   EXPECT_TRUE(A == Op32);
   EXPECT_TRUE(sd_match(AExt, m_AExtOrSelf(m_Value(A))));

>From c48faa5c4439fbfb762fd8b1489ab51214cb96d8 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 1 May 2024 12:06:53 -0700
Subject: [PATCH 2/4] fixup! [SDPatternMatch] Add m_CondCode, m_NoneOf, and
 some ZExt/SExt improvements

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 822086c3a5ea12..54f5030bd48fa3 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -806,6 +806,7 @@ inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
 inline CondCode_match m_CondCode(ISD::CondCode &CC) {
   return CondCode_match(&CC);
 }
+/// Match a condition code node (CondCodeSDNode) holding exactly CC.
 inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
   return CondCode_match(CC);
 }

>From 61ed677f2c53d4d73ae6a9e6af0ba861d8077541 Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 1 May 2024 13:24:21 -0700
Subject: [PATCH 3/4] Remove incorrect ZExt pattern

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 24 ++++++++-----------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  |  4 ----
 2 files changed, 10 insertions(+), 18 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 54f5030bd48fa3..dc17eba856aefe 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -633,6 +633,10 @@ inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc,
   return UnaryOpc_match<Opnd, true>(Opc, Op);
 }
 
+template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
+  return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
+}
+
 template <typename Opnd> inline auto m_SExt(Opnd &&Op) {
   return m_AnyOf(
       UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op),
@@ -647,6 +651,12 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
   return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op);
 }
 
+/// Match a zext or identity
+/// Allows to peek through optional extensions
+template <typename Opnd> inline auto m_ZExtOrSelf(Opnd &&Op) {
+  return m_AnyOf(m_ZExt(std::forward<Opnd>(Op)), std::forward<Opnd>(Op));
+}
+
 /// Match a sext or identity
 /// Allows to peek through optional extensions
 template <typename Opnd> inline auto m_SExtOrSelf(Opnd &&Op) {
@@ -723,20 +733,6 @@ inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
 inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
 inline SpecificInt_match m_AllOnes() { return m_SpecificInt(~0U); }
 
-// FIXME: We probably ought to move constant matchers before
-// most of the others so that m_ZExt/m_ZExtOrSelf can be
-// with other extension matchers.
-template <typename Opnd> inline auto m_ZExt(const Opnd &Op) {
-  return m_AnyOf(UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op),
-                 m_And(Op, m_AllOnes()));
-}
-
-/// Match a zext or identity
-/// Allows to peek through optional extensions
-template <typename Opnd> inline auto m_ZExtOrSelf(Opnd &&Op) {
-  return m_AnyOf(m_ZExt(std::forward<Opnd>(Op)), std::forward<Opnd>(Op));
-}
-
 /// Match true boolean value based on the information provided by
 /// TargetLowering.
 inline auto m_True() {
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index bbe11aa7ecaaa7..24930b965f1def 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -268,8 +268,6 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
   SDValue Op32 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
   SDValue Op64 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
   SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op32);
-  SDValue AndZExt = DAG->getNode(ISD::AND, DL, Int64VT, Op64,
-                                 DAG->getConstant(~0U, DL, Int64VT));
   SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op32);
   SDValue SExtInReg = DAG->getNode(ISD::SIGN_EXTEND_INREG, DL, Int64VT, Op64,
                                    DAG->getValueType(Int32VT));
@@ -282,8 +280,6 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
   EXPECT_TRUE(A == Op32);
   EXPECT_TRUE(sd_match(ZExt, m_ZExtOrSelf(m_Value(A))));
   EXPECT_TRUE(A == Op32);
-  EXPECT_TRUE(sd_match(AndZExt, m_ZExtOrSelf(m_Value(A))));
-  EXPECT_TRUE(A == Op64);
   EXPECT_TRUE(sd_match(Op64, m_SExtOrSelf(m_Value(A))));
   EXPECT_TRUE(A == Op64);
   EXPECT_TRUE(sd_match(SExt, m_SExtOrSelf(m_Value(A))));

>From 8664b5802e4319bc896e516ba6e08fb99d6b4c1e Mon Sep 17 00:00:00 2001
From: Min Hsu <min.hsu at sifive.com>
Date: Wed, 1 May 2024 15:32:45 -0700
Subject: [PATCH 4/4] Address review comment: add m_IsNot and use it in m_NoneOf

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index dc17eba856aefe..b4fd5d4be0e2e8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -371,6 +371,10 @@ template <typename Pred> struct Not {
 // Explicit deduction guide.
 template <typename Pred> Not(const Pred &P) -> Not<Pred>;
 
+template <typename Pred> inline Not<Pred> m_IsNot(const Pred &P) {
+  return Not{P};
+}
+
 template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) {
   return And<Preds...>(std::forward<Preds>(preds)...);
 }
@@ -380,7 +384,7 @@ template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) {
 }
 
 template <typename... Preds> auto m_NoneOf(Preds &&...preds) {
-  return Not{m_AnyOf(std::forward<Preds>(preds)...)};
+  return m_IsNot(m_AnyOf(std::forward<Preds>(preds)...));
 }
 
 // === Generic node matching ===



More information about the llvm-commits mailing list