[llvm] 0638e22 - [SDPatternMatch] Add m_CondCode, m_NoneOf, and some SExt improvements (#90762)

via llvm-commits llvm-commits at lists.llvm.org
Thu May 2 08:56:47 PDT 2024


Author: Min-Yih Hsu
Date: 2024-05-02T08:56:42-07:00
New Revision: 0638e222f363e041ffe3c4a0f371f987ab92a03d

URL: https://github.com/llvm/llvm-project/commit/0638e222f363e041ffe3c4a0f371f987ab92a03d
DIFF: https://github.com/llvm/llvm-project/commit/0638e222f363e041ffe3c4a0f371f987ab92a03d.diff

LOG: [SDPatternMatch] Add m_CondCode, m_NoneOf, and some SExt improvements (#90762)

  - Add m_CondCode to match the ISD::CondCode value from CondCodeSDNode
  - Add m_NoneOf combinator
  - m_SExt (and thus m_SExtOrSelf) now also matches sext_inreg, matching the
    inner pattern against its same-width first operand

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SDPatternMatch.h
    llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index c581eb7a60aac9..f34204c549b68e 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -358,6 +358,24 @@ struct Or<Pred, Preds...> : Or<Preds...> {
   }
 };
 
+template <typename Pred> struct Not {
+  Pred P;
+
+  explicit Not(const Pred &P) : P(P) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    return !P.match(Ctx, N);
+  }
+};
+// Deduction guide so `Not{P}` (see m_Unless) deduces Pred from P's type.
+template <typename Pred> Not(const Pred &P) -> Not<Pred>;
+
+/// Match if \p P does NOT match. Bindings made by P are not rolled back.
+template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) {
+  return Not{P};
+}
+
 template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) {
   return And<Preds...>(std::forward<Preds>(preds)...);
 }
@@ -366,6 +384,10 @@ template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) {
   return Or<Preds...>(std::forward<Preds>(preds)...);
 }
 
+template <typename... Preds> auto m_NoneOf(Preds &&...preds) {
+  return m_Unless(m_AnyOf(std::forward<Preds>(preds)...));
+}
+
 // === Generic node matching ===
 template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
   template <typename MatchContext>
@@ -620,8 +642,13 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
   return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
 }
 
-template <typename Opnd> inline UnaryOpc_match<Opnd> m_SExt(const Opnd &Op) {
-  return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
+/// Match (sext X) or (sign_extend_inreg X, VT), matching X against \p Op.
+/// For sign_extend_inreg, X has the same type as the result, so callers
+/// must not assume the bound value is narrower than the matched node.
+template <typename Opnd> inline auto m_SExt(Opnd &&Op) {
+  return m_AnyOf(
+      UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op),
+      m_Node(ISD::SIGN_EXTEND_INREG, std::forward<Opnd>(Op), m_Value()));
 }
 
 template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
@@ -634,18 +661,15 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
 
 /// Match a zext or identity
 /// Allows to peek through optional extensions
-template <typename Opnd>
-inline Or<UnaryOpc_match<Opnd>, Opnd> m_ZExtOrSelf(Opnd &&Op) {
-  return Or<UnaryOpc_match<Opnd>, Opnd>(m_ZExt(std::forward<Opnd>(Op)),
-                                        std::forward<Opnd>(Op));
+template <typename Opnd> inline auto m_ZExtOrSelf(Opnd &&Op) {
+  return m_AnyOf(m_ZExt(std::forward<Opnd>(Op)), std::forward<Opnd>(Op));
 }
 
 /// Match a sext or identity
 /// Allows to peek through optional extensions
-template <typename Opnd>
-inline Or<UnaryOpc_match<Opnd>, Opnd> m_SExtOrSelf(Opnd &&Op) {
-  return Or<UnaryOpc_match<Opnd>, Opnd>(m_SExt(std::forward<Opnd>(Op)),
-                                        std::forward<Opnd>(Op));
+/// Via m_SExt, this also looks through sign_extend_inreg.
+template <typename Opnd> inline auto m_SExtOrSelf(Opnd &&Op) {
+  return m_AnyOf(m_SExt(std::forward<Opnd>(Op)), std::forward<Opnd>(Op));
 }
 
 /// Match a aext or identity
@@ -768,6 +788,39 @@ inline auto m_False() {
       m_Value()};
 }
 
+struct CondCode_match {
+  std::optional<ISD::CondCode> CCToMatch; // If set, the code required.
+  ISD::CondCode *BindCC = nullptr; // If non-null, receives the matched code.
+
+  explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {}
+
+  explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    if (auto *CC = dyn_cast<CondCodeSDNode>(N.getNode())) {
+      if (CCToMatch && *CCToMatch != CC->get())
+        return false;
+
+      if (BindCC)
+        *BindCC = CC->get();
+      return true;
+    }
+
+    return false;
+  }
+};
+
+/// Match any condition code SDNode.
+inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
+/// Match any condition code SDNode and store its ISD::CondCode in \p CC.
+inline CondCode_match m_CondCode(ISD::CondCode &CC) {
+  return CondCode_match(&CC);
+}
+/// Match a condition code SDNode holding exactly the ISD::CondCode \p CC.
+inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
+  return CondCode_match(CC);
+}
+
 /// Match a negate as a sub(0, v)
 template <typename ValTy>
 inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {

diff  --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index a7112cfac63de5..24930b965f1def 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -217,6 +217,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
   SDValue Zero = DAG->getConstant(0, DL, Int32VT);
   SDValue One = DAG->getConstant(1, DL, Int32VT);
   SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, Int32VT);
+  SDValue SetCC = DAG->getSetCC(DL, Int32VT, Arg0, Const3, ISD::SETULT);
 
   using namespace SDPatternMatch;
   EXPECT_TRUE(sd_match(Const87, m_ConstInt()));
@@ -233,6 +234,15 @@ TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
   EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False()));
   EXPECT_TRUE(sd_match(One, DAG.get(), m_True()));
   EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True()));
+
+  ISD::CondCode CC = ISD::SETCC_INVALID; // Defined even if the match fails.
+  EXPECT_TRUE(sd_match(
+      SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(), m_CondCode(CC))));
+  EXPECT_EQ(CC, ISD::SETULT);
+  EXPECT_TRUE(sd_match(SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(),
+                                     m_SpecificCondCode(ISD::SETULT))));
+  EXPECT_FALSE(sd_match(SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(),
+                                      m_SpecificCondCode(ISD::SETEQ))));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
@@ -249,6 +259,8 @@ TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
   EXPECT_TRUE(sd_match(
       Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
   EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse())));
+  EXPECT_TRUE(sd_match(Add, m_NoneOf(m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
+  EXPECT_FALSE(sd_match(Add, m_NoneOf(m_Opc(ISD::ADD), m_Opc(ISD::MUL))));
 }
 
 TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
@@ -260,6 +272,8 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
   SDValue Op64 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
   SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op32);
   SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op32);
+  SDValue SExtInReg = DAG->getNode(ISD::SIGN_EXTEND_INREG, DL, Int64VT, Op64,
+                                   DAG->getValueType(Int32VT));
   SDValue AExt = DAG->getNode(ISD::ANY_EXTEND, DL, Int64VT, Op32);
   SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op64);
 
@@ -273,6 +287,8 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
   EXPECT_TRUE(A == Op64);
   EXPECT_TRUE(sd_match(SExt, m_SExtOrSelf(m_Value(A))));
   EXPECT_TRUE(A == Op32);
+  EXPECT_TRUE(sd_match(SExtInReg, m_SExtOrSelf(m_Value(A))));
+  EXPECT_TRUE(A == Op64);
   EXPECT_TRUE(sd_match(Op32, m_AExtOrSelf(m_Value(A))));
   EXPECT_TRUE(A == Op32);
   EXPECT_TRUE(sd_match(AExt, m_AExtOrSelf(m_Value(A))));


        


More information about the llvm-commits mailing list