[llvm] 0638e22 - [SDPatternMatch] Add m_CondCode, m_NoneOf, and some SExt improvements (#90762)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 2 08:56:47 PDT 2024
Author: Min-Yih Hsu
Date: 2024-05-02T08:56:42-07:00
New Revision: 0638e222f363e041ffe3c4a0f371f987ab92a03d
URL: https://github.com/llvm/llvm-project/commit/0638e222f363e041ffe3c4a0f371f987ab92a03d
DIFF: https://github.com/llvm/llvm-project/commit/0638e222f363e041ffe3c4a0f371f987ab92a03d.diff
LOG: [SDPatternMatch] Add m_CondCode, m_NoneOf, and some SExt improvements (#90762)
- Add m_CondCode and m_SpecificCondCode to match the ISD::CondCode value of a CondCodeSDNode
- Add the m_NoneOf combinator (built on a new m_Unless negation combinator)
- m_SExt now also recognizes sext_inreg
Added:
Modified:
llvm/include/llvm/CodeGen/SDPatternMatch.h
llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index c581eb7a60aac9..f34204c549b68e 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -358,6 +358,24 @@ struct Or<Pred, Preds...> : Or<Preds...> {
}
};
+/// Pattern combinator that inverts another pattern: it matches exactly when
+/// the wrapped pattern \p P does not match. No rollback is performed, so any
+/// bindings made while evaluating the inner pattern are left in place.
+template <typename Pred> struct Not {
+  Pred P;
+
+  explicit Not(const Pred &P) : P(P) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    return !P.match(Ctx, N);
+  }
+};
+// Explicit deduction guide for direct `Not{P}` construction. Note that when P
+// is itself a Not<X>, CTAD prefers the copy deduction candidate and `Not{P}`
+// copies P instead of wrapping it.
+template <typename Pred> Not(const Pred &P) -> Not<Pred>;
+
+/// Match if the inner pattern does NOT match.
+template <typename Pred> inline Not<Pred> m_Unless(const Pred &P) {
+  // Spell out Not<Pred> rather than relying on `Not{P}`: for P of type Not<X>
+  // (e.g. m_Unless(m_NoneOf(...))) CTAD would deduce Not<X>, which cannot be
+  // returned as Not<Not<X>> through the explicit constructor.
+  return Not<Pred>(P);
+}
+
template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) {
return And<Preds...>(std::forward<Preds>(preds)...);
}
@@ -366,6 +384,10 @@ template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) {
return Or<Preds...>(std::forward<Preds>(preds)...);
}
+/// Match if none of the given patterns match; this is exactly
+/// m_Unless(m_AnyOf(preds...)).
+template <typename... Preds> auto m_NoneOf(Preds &&...preds) {
+ return m_Unless(m_AnyOf(std::forward<Preds>(preds)...));
+}
+
// === Generic node matching ===
template <unsigned OpIdx, typename... OpndPreds> struct Operands_match {
template <typename MatchContext>
@@ -620,8 +642,10 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
}
-template <typename Opnd> inline UnaryOpc_match<Opnd> m_SExt(const Opnd &Op) {
- return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
+/// Match a sign extension of \p Op: either (sign_extend Op) or
+/// (sign_extend_inreg Op, VT). For sign_extend_inreg, \p Op is matched
+/// against the node's first operand (which has the same type as the result)
+/// and the value-type operand is left unconstrained.
+template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
+  // Op feeds two sub-patterns, so copy it into each instead of forwarding it:
+  // the evaluation order of m_AnyOf's arguments is unspecified, and a move
+  // into one alternative could happen before the other one copies Op.
+  return m_AnyOf(UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op),
+                 m_Node(ISD::SIGN_EXTEND_INREG, Opnd(Op), m_Value()));
}
template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) {
@@ -634,18 +658,14 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) {
/// Match a zext or identity
/// Allows to peek through optional extensions
-template <typename Opnd>
-inline Or<UnaryOpc_match<Opnd>, Opnd> m_ZExtOrSelf(Opnd &&Op) {
- return Or<UnaryOpc_match<Opnd>, Opnd>(m_ZExt(std::forward<Opnd>(Op)),
- std::forward<Opnd>(Op));
+template <typename Opnd> inline auto m_ZExtOrSelf(Opnd &&Op) {
+ // Forwarding Op twice is safe here: m_ZExt takes its operand by const
+ // reference and copies it, so Op is moved from at most once (when m_AnyOf
+ // builds its Or), after that copy has been made.
+ return m_AnyOf(m_ZExt(std::forward<Opnd>(Op)), std::forward<Opnd>(Op));
}
/// Match a sext or identity
/// Allows to peek through optional extensions
-template <typename Opnd>
-inline Or<UnaryOpc_match<Opnd>, Opnd> m_SExtOrSelf(Opnd &&Op) {
- return Or<UnaryOpc_match<Opnd>, Opnd>(m_SExt(std::forward<Opnd>(Op)),
- std::forward<Opnd>(Op));
+template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
+  // Give each alternative its own copy of Op. m_SExt may forward (move) its
+  // argument into a sub-pattern, so forwarding Op both to m_SExt and to
+  // m_AnyOf would leave one alternative holding a moved-from pattern.
+  return m_AnyOf(m_SExt(Opnd(Op)), Opnd(Op));
}
/// Match a aext or identity
@@ -768,6 +788,39 @@ inline auto m_False() {
m_Value()};
}
+/// Matches a CondCodeSDNode, i.e. the condition-code operand of nodes such as
+/// ISD::SETCC. When constructed from an ISD::CondCode, the node's code must
+/// also equal that value. When constructed from a non-null ISD::CondCode
+/// pointer, the node's code is stored through that pointer on a successful
+/// match.
+struct CondCode_match {
+ /// Required condition code; unset means any code is accepted.
+ std::optional<ISD::CondCode> CCToMatch;
+ /// Where to store the matched code; null means nothing is bound.
+ ISD::CondCode *BindCC = nullptr;
+
+ explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {}
+
+ explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {}
+
+ template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+ if (auto *CC = dyn_cast<CondCodeSDNode>(N.getNode())) {
+ // Check before binding so a rejected node never writes *BindCC.
+ if (CCToMatch && *CCToMatch != CC->get())
+ return false;
+
+ if (BindCC)
+ *BindCC = CC->get();
+ return true;
+ }
+
+ return false;
+ }
+};
+
+/// Match any condition code SDNode (no value check, nothing bound).
+inline CondCode_match m_CondCode() { return CondCode_match(nullptr); }
+/// Match any condition code SDNode and store its ISD::CondCode value in \p CC.
+inline CondCode_match m_CondCode(ISD::CondCode &CC) {
+ return CondCode_match(&CC);
+}
+/// Match a condition code SDNode with a specific ISD::CondCode.
+inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
+ return CondCode_match(CC);
+}
+
/// Match a negate as a sub(0, v)
template <typename ValTy>
inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index a7112cfac63de5..24930b965f1def 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -217,6 +217,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
SDValue Zero = DAG->getConstant(0, DL, Int32VT);
SDValue One = DAG->getConstant(1, DL, Int32VT);
SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, Int32VT);
+ SDValue SetCC = DAG->getSetCC(DL, Int32VT, Arg0, Const3, ISD::SETULT);
using namespace SDPatternMatch;
EXPECT_TRUE(sd_match(Const87, m_ConstInt()));
@@ -233,6 +234,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchConstants) {
EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False()));
EXPECT_TRUE(sd_match(One, DAG.get(), m_True()));
EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True()));
+
+ ISD::CondCode CC;
+ EXPECT_TRUE(sd_match(
+ SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(), m_CondCode(CC))));
+ EXPECT_EQ(CC, ISD::SETULT);
+ EXPECT_TRUE(sd_match(SetCC, m_Node(ISD::SETCC, m_Value(), m_Value(),
+ m_SpecificCondCode(ISD::SETULT))));
}
TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
@@ -249,6 +257,7 @@ TEST_F(SelectionDAGPatternMatchTest, patternCombinators) {
EXPECT_TRUE(sd_match(
Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse())));
+ EXPECT_TRUE(sd_match(Add, m_NoneOf(m_Opc(ISD::SUB), m_Opc(ISD::MUL))));
}
TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
@@ -260,6 +269,8 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
SDValue Op64 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op32);
SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op32);
+ SDValue SExtInReg = DAG->getNode(ISD::SIGN_EXTEND_INREG, DL, Int64VT, Op64,
+ DAG->getValueType(Int32VT));
SDValue AExt = DAG->getNode(ISD::ANY_EXTEND, DL, Int64VT, Op32);
SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op64);
@@ -273,6 +284,8 @@ TEST_F(SelectionDAGPatternMatchTest, optionalResizing) {
EXPECT_TRUE(A == Op64);
EXPECT_TRUE(sd_match(SExt, m_SExtOrSelf(m_Value(A))));
EXPECT_TRUE(A == Op32);
+ EXPECT_TRUE(sd_match(SExtInReg, m_SExtOrSelf(m_Value(A))));
+ EXPECT_TRUE(A == Op64);
EXPECT_TRUE(sd_match(Op32, m_AExtOrSelf(m_Value(A))));
EXPECT_TRUE(A == Op32);
EXPECT_TRUE(sd_match(AExt, m_AExtOrSelf(m_Value(A))));
More information about the llvm-commits
mailing list