[llvm] 70f3863 - [DAG][PatternMatch] Add support for matchers with flags; NFC
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Sun Aug 18 15:38:12 PDT 2024
Author: Noah Goldstein
Date: 2024-08-18T15:37:56-07:00
New Revision: 70f3863b5f30e856278f399b068a30bc4d5d16c2
URL: https://github.com/llvm/llvm-project/commit/70f3863b5f30e856278f399b068a30bc4d5d16c2
DIFF: https://github.com/llvm/llvm-project/commit/70f3863b5f30e856278f399b068a30bc4d5d16c2.diff
LOG: [DAG][PatternMatch] Add support for matchers with flags; NFC
Add support for matching with `SDNodeFlags`, e.g. `add` with `nuw`.
This patch adds helpers for `or disjoint` and `zext nneg` with the same
names as in the IR PatternMatch API.
Closes #103060
Added:
Modified:
llvm/include/llvm/CodeGen/SDPatternMatch.h
llvm/include/llvm/CodeGen/SelectionDAGNodes.h
llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index b1aa87ca2d3e13..92efff93f60f89 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -514,19 +514,28 @@ struct BinaryOpc_match {
unsigned Opcode;
LHS_P LHS;
RHS_P RHS;
-
- BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R)
- : Opcode(Opc), LHS(L), RHS(R) {}
+ std::optional<SDNodeFlags> Flags;
+ BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
+ std::optional<SDNodeFlags> Flgs = std::nullopt)
+ : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 2);
- return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
- RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
- (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
- RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
+ if (!((LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+ RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
+ (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+ RHS.match(Ctx, N->getOperand(EO.FirstIndex)))))
+ return false;
+
+ if (!Flags.has_value())
+ return true;
+
+ SDNodeFlags TmpFlags = *Flags;
+ TmpFlags.intersectWith(N->getFlags());
+ return TmpFlags == *Flags;
}
return false;
@@ -581,6 +590,19 @@ inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
}
+/// Match an ISD::OR node whose SDNodeFlags include `disjoint`, with operands
+/// matching L and R. Uses the same BinaryOpc_match<..., true> instantiation as
+/// m_Or. The flag check is a subset test: a node carrying `disjoint` plus
+/// other flags still matches, while a plain `or` does not.
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
+ const RHS &R) {
+ SDNodeFlags Flags;
+ Flags.setDisjoint(true);
+ return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, Flags);
+}
+
+/// Match either an ISD::ADD node or an `or disjoint` node (see m_DisjointOr)
+/// whose operands match L and R.
+template <typename LHS, typename RHS>
+inline auto m_AddLike(const LHS &L, const RHS &R) {
+ return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
+}
+
template <typename LHS, typename RHS>
inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
@@ -667,15 +689,24 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
unsigned Opcode;
Opnd_P Opnd;
-
- UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {}
+ std::optional<SDNodeFlags> Flags;
+ UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
+ std::optional<SDNodeFlags> Flgs = std::nullopt)
+ : Opcode(Opc), Opnd(Op), Flags(Flgs) {}
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N, Ctx);
assert(EO.Size == 1);
- return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
+ if (!Opnd.match(Ctx, N->getOperand(EO.FirstIndex)))
+ return false;
+ if (!Flags.has_value())
+ return true;
+
+ SDNodeFlags TmpFlags = *Flags;
+ TmpFlags.intersectWith(N->getFlags());
+ return TmpFlags == *Flags;
}
return false;
@@ -701,6 +732,13 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
}
+/// Match an ISD::ZERO_EXTEND node whose SDNodeFlags include `nneg` and whose
+/// operand matches Op. A `zext` without `nneg` does not match.
+template <typename Opnd>
+inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
+ SDNodeFlags Flags;
+ Flags.setNonNeg(true);
+ return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, Flags);
+}
+
template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
}
@@ -725,6 +763,10 @@ template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
return m_AnyOf(m_SExt(Op), Op);
}
+/// Match either an ISD::SIGN_EXTEND node or a `zext nneg` node (see
+/// m_NNegZExt) whose operand matches Op.
+template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
+ return m_AnyOf(m_SExt(Op), m_NNegZExt(Op));
+}
+
/// Match a aext or identity
/// Allows to peek through optional extensions
template <typename Opnd>
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 2f36c2e86b1c3a..88549d9c9a2858 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -452,6 +452,20 @@ struct SDNodeFlags {
bool hasNoFPExcept() const { return NoFPExcept; }
bool hasUnpredictable() const { return Unpredictable; }
+ /// Memberwise equality of the flag fields listed below. The SDPatternMatch
+ /// flag matchers rely on this to test "required flags are a subset of the
+ /// node's flags" (intersect, then compare with ==).
+ /// NOTE(review): a flag added to SDNodeFlags but not compared here would be
+ /// silently ignored by those matchers; keep this list in sync.
+ bool operator==(const SDNodeFlags &Other) const {
+ return NoUnsignedWrap == Other.NoUnsignedWrap &&
+ NoSignedWrap == Other.NoSignedWrap && Exact == Other.Exact &&
+ Disjoint == Other.Disjoint && NonNeg == Other.NonNeg &&
+ NoNaNs == Other.NoNaNs && NoInfs == Other.NoInfs &&
+ NoSignedZeros == Other.NoSignedZeros &&
+ AllowReciprocal == Other.AllowReciprocal &&
+ AllowContract == Other.AllowContract &&
+ ApproximateFuncs == Other.ApproximateFuncs &&
+ AllowReassociation == Other.AllowReassociation &&
+ NoFPExcept == Other.NoFPExcept &&
+ Unpredictable == Other.Unpredictable;
+ }
+
/// Clear any flags in this flag set that aren't also set in Flags. All
/// flags will be cleared if Flags are undefined.
void intersectWith(const SDNodeFlags Flags) {
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index c04fc5621ab499..e66584b81bba25 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -185,6 +185,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
+ SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
@@ -192,6 +193,9 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0);
SDValue Or = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
+ SDNodeFlags DisFlags;
+ DisFlags.setDisjoint(true);
+ SDValue DisOr = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op3, DisFlags);
SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1);
SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0);
SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1);
@@ -205,6 +209,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(Add, m_AddLike(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(
Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
EXPECT_TRUE(
@@ -217,6 +222,12 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
+ EXPECT_FALSE(sd_match(Or, m_DisjointOr(m_Value(), m_Value())));
+
+ EXPECT_TRUE(sd_match(DisOr, m_Or(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(DisOr, m_DisjointOr(m_Value(), m_Value())));
+ EXPECT_FALSE(sd_match(DisOr, m_Add(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(DisOr, m_AddLike(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value())));
@@ -242,9 +253,14 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
- SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, FloatVT);
+ SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, FloatVT);
+ SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
+ SDNodeFlags NNegFlags;
+ NNegFlags.setNonNeg(true);
+ SDValue ZExtNNeg =
+ DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op3, NNegFlags);
SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0);
SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1);
@@ -260,6 +276,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
using namespace SDPatternMatch;
EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
+ EXPECT_TRUE(sd_match(SExt, m_SExtLike(m_Value())));
+ ASSERT_TRUE(ZExtNNeg->getFlags().hasNonNeg());
+ EXPECT_FALSE(sd_match(ZExtNNeg, m_SExt(m_Value())));
+ EXPECT_TRUE(sd_match(ZExtNNeg, m_NNegZExt(m_Value())));
+ EXPECT_FALSE(sd_match(ZExt, m_NNegZExt(m_Value())));
+ EXPECT_TRUE(sd_match(ZExtNNeg, m_SExtLike(m_Value())));
+ EXPECT_FALSE(sd_match(ZExt, m_SExtLike(m_Value())));
EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));
EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));
More information about the llvm-commits mailing list