[llvm] [DAG] SDPatternMatch - Add matchers for reassociatable additions with NSW/NUW flags (PR #177973)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 28 05:31:27 PST 2026
https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/177973
>From 18b9376b454666b449069d212257ea788ad8a0e8 Mon Sep 17 00:00:00 2001
From: crisiumnih <fasfasmag at proton.me>
Date: Mon, 26 Jan 2026 20:26:13 +0530
Subject: [PATCH 1/3] Add NSW/NUW flag support to reassociatable add matchers
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 25 ++++++++++-
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 44 +++++++++++++++++++
2 files changed, 68 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index cb209dc3482be..338357d9884b1 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1342,9 +1342,15 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
constexpr static size_t NumPatterns =
std::tuple_size_v<std::tuple<PatternTs...>>;
+ std::optional<SDNodeFlags> Flags;
+
ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
: Opcode(Opcode), Patterns(Patterns...) {}
+ ReassociatableOpc_match(unsigned Opcode, SDNodeFlags Flags,
+ const PatternTs &...Patterns)
+ : Opcode(Opcode), Patterns(Patterns...), Flags(Flags) {}
+
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
std::array<SDValue, NumPatterns> Leaves;
@@ -1362,7 +1368,8 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
bool collectLeaves(SDValue V, std::array<SDValue, NumPatterns> &Leaves,
std::size_t &LeafIdx) {
- if (V->getOpcode() == Opcode) {
+ if (V->getOpcode() == Opcode &&
+ (!Flags || (*Flags & V->getFlags()) == *Flags)) {
for (size_t I = 0, N = V->getNumOperands(); I < N; I++)
if ((LeafIdx == NumPatterns) ||
!collectLeaves(V->getOperand(I), Leaves, LeafIdx))
@@ -1423,6 +1430,22 @@ m_ReassociatableMul(const PatternTs &...Patterns) {
return ReassociatableOpc_match<PatternTs...>(ISD::MUL, Patterns...);
}
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableNSWAdd(const PatternTs &...Patterns) {
+ return ReassociatableOpc_match<PatternTs...>(ISD::ADD,
+ SDNodeFlags::NoSignedWrap,
+ Patterns...);
+}
+
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableNUWAdd(const PatternTs &...Patterns) {
+ return ReassociatableOpc_match<PatternTs...>(ISD::ADD,
+ SDNodeFlags::NoUnsignedWrap,
+ Patterns...);
+}
+
} // namespace SDPatternMatch
} // namespace llvm
#endif
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 900d39f91e303..c052031be4cda 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -1173,3 +1173,47 @@ TEST_F(SelectionDAGPatternMatchTest, MatchSpecificNeg) {
SDValue Zero = DAG->getConstant(0, DL, Int32VT);
EXPECT_TRUE(sd_match(Zero, m_SpecificNeg(Zero)));
}
+
+TEST_F(SelectionDAGPatternMatchTest, matchReassociatableFlags) {
+ SDLoc DL;
+ auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+ SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+ SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+ SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
+
+ SDNodeFlags NSWFlags;
+ NSWFlags.setNoSignedWrap(true);
+
+ // (Op0 +nsw Op1) +nsw Op2
+ SDValue Add0 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1, NSWFlags);
+ SDValue Add1 = DAG->getNode(ISD::ADD, DL, Int32VT, Add0, Op2, NSWFlags);
+
+ SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 4, Int32VT);
+ SDValue Op4 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 5, Int32VT);
+ SDValue Op5 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 6, Int32VT);
+
+ // (Op3 + Op4) +nsw Op5
+ SDValue Add2 = DAG->getNode(ISD::ADD, DL, Int32VT, Op3, Op4);
+ SDValue Add3 = DAG->getNode(ISD::ADD, DL, Int32VT, Add2, Op5, NSWFlags);
+
+ SDValue Op6 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 7, Int32VT);
+ SDValue Op7 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
+ SDValue Op8 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 9, Int32VT);
+
+ SDNodeFlags NUWFlags;
+ NUWFlags.setNoUnsignedWrap(true);
+
+ // (Op6 +nuw Op7) +nuw Op8
+ SDValue Add4 = DAG->getNode(ISD::ADD, DL, Int32VT, Op6, Op7, NUWFlags);
+ SDValue Add5 = DAG->getNode(ISD::ADD, DL, Int32VT, Add4, Op8, NUWFlags);
+
+ using namespace SDPatternMatch;
+
+ EXPECT_TRUE(sd_match(Add1, m_ReassociatableNSWAdd(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+ EXPECT_FALSE(sd_match(Add3, m_ReassociatableNSWAdd(m_Specific(Op3), m_Specific(Op4), m_Specific(Op5))));
+ EXPECT_TRUE(sd_match(Add3, m_ReassociatableNSWAdd(m_Specific(Add2), m_Specific(Op5))));
+
+ EXPECT_TRUE(sd_match(Add5, m_ReassociatableNUWAdd(m_Specific(Op6), m_Specific(Op7), m_Specific(Op8))));
+ EXPECT_FALSE(sd_match(Add1, m_ReassociatableNUWAdd(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+}
>From f1b4d8ac080c2140a07cd0b18f1a2eb55804e751 Mon Sep 17 00:00:00 2001
From: crisiumnih <fasfasmag at proton.me>
Date: Mon, 26 Jan 2026 20:35:49 +0530
Subject: [PATCH 2/3] Apply clang-format to the new matchers and tests
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 10 ++++------
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 19 ++++++++++++++-----
2 files changed, 18 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 338357d9884b1..0d08210573e12 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1433,17 +1433,15 @@ m_ReassociatableMul(const PatternTs &...Patterns) {
template <typename... PatternTs>
inline ReassociatableOpc_match<PatternTs...>
m_ReassociatableNSWAdd(const PatternTs &...Patterns) {
- return ReassociatableOpc_match<PatternTs...>(ISD::ADD,
- SDNodeFlags::NoSignedWrap,
- Patterns...);
+ return ReassociatableOpc_match<PatternTs...>(
+ ISD::ADD, SDNodeFlags::NoSignedWrap, Patterns...);
}
template <typename... PatternTs>
inline ReassociatableOpc_match<PatternTs...>
m_ReassociatableNUWAdd(const PatternTs &...Patterns) {
- return ReassociatableOpc_match<PatternTs...>(ISD::ADD,
- SDNodeFlags::NoUnsignedWrap,
- Patterns...);
+ return ReassociatableOpc_match<PatternTs...>(
+ ISD::ADD, SDNodeFlags::NoUnsignedWrap, Patterns...);
}
} // namespace SDPatternMatch
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index c052031be4cda..8fc3739d46767 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -1210,10 +1210,19 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableFlags) {
using namespace SDPatternMatch;
- EXPECT_TRUE(sd_match(Add1, m_ReassociatableNSWAdd(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
- EXPECT_FALSE(sd_match(Add3, m_ReassociatableNSWAdd(m_Specific(Op3), m_Specific(Op4), m_Specific(Op5))));
- EXPECT_TRUE(sd_match(Add3, m_ReassociatableNSWAdd(m_Specific(Add2), m_Specific(Op5))));
+ EXPECT_TRUE(
+ sd_match(Add1, m_ReassociatableNSWAdd(m_Specific(Op0), m_Specific(Op1),
+ m_Specific(Op2))));
+ EXPECT_FALSE(
+ sd_match(Add3, m_ReassociatableNSWAdd(m_Specific(Op3), m_Specific(Op4),
+ m_Specific(Op5))));
+ EXPECT_TRUE(sd_match(
+ Add3, m_ReassociatableNSWAdd(m_Specific(Add2), m_Specific(Op5))));
- EXPECT_TRUE(sd_match(Add5, m_ReassociatableNUWAdd(m_Specific(Op6), m_Specific(Op7), m_Specific(Op8))));
- EXPECT_FALSE(sd_match(Add1, m_ReassociatableNUWAdd(m_Specific(Op0), m_Specific(Op1), m_Specific(Op2))));
+ EXPECT_TRUE(
+ sd_match(Add5, m_ReassociatableNUWAdd(m_Specific(Op6), m_Specific(Op7),
+ m_Specific(Op8))));
+ EXPECT_FALSE(
+ sd_match(Add1, m_ReassociatableNUWAdd(m_Specific(Op0), m_Specific(Op1),
+ m_Specific(Op2))));
}
>From ab61b3a5444a339230fb438936fc467e5cd0a467 Mon Sep 17 00:00:00 2001
From: crisiumnih <fasfasmag at proton.me>
Date: Mon, 26 Jan 2026 22:10:43 +0530
Subject: [PATCH 3/3] Store Flags as a plain SDNodeFlags member; test nodes with both NSW and NUW
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 5 ++---
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 18 ++++++++++++++++++
2 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 0d08210573e12..785254c894d01 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1342,7 +1342,7 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
constexpr static size_t NumPatterns =
std::tuple_size_v<std::tuple<PatternTs...>>;
- std::optional<SDNodeFlags> Flags;
+ SDNodeFlags Flags;
ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
: Opcode(Opcode), Patterns(Patterns...) {}
@@ -1368,8 +1368,7 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
bool collectLeaves(SDValue V, std::array<SDValue, NumPatterns> &Leaves,
std::size_t &LeafIdx) {
- if (V->getOpcode() == Opcode &&
- (!Flags || (*Flags & V->getFlags()) == *Flags)) {
+ if (V->getOpcode() == Opcode && (Flags & V->getFlags()) == Flags) {
for (size_t I = 0, N = V->getNumOperands(); I < N; I++)
if ((LeafIdx == NumPatterns) ||
!collectLeaves(V->getOperand(I), Leaves, LeafIdx))
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 8fc3739d46767..61c509608708c 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -1208,6 +1208,18 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableFlags) {
SDValue Add4 = DAG->getNode(ISD::ADD, DL, Int32VT, Op6, Op7, NUWFlags);
SDValue Add5 = DAG->getNode(ISD::ADD, DL, Int32VT, Add4, Op8, NUWFlags);
+ // (Op9 +nsw+nuw Op10) +nsw+nuw Op11
+ SDNodeFlags BothFlags;
+ BothFlags.setNoSignedWrap(true);
+ BothFlags.setNoUnsignedWrap(true);
+
+ SDValue Op9 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 10, Int32VT);
+ SDValue Op10 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 11, Int32VT);
+ SDValue Op11 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 12, Int32VT);
+
+ SDValue Add6 = DAG->getNode(ISD::ADD, DL, Int32VT, Op9, Op10, BothFlags);
+ SDValue Add7 = DAG->getNode(ISD::ADD, DL, Int32VT, Add6, Op11, BothFlags);
+
using namespace SDPatternMatch;
EXPECT_TRUE(
@@ -1225,4 +1237,10 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableFlags) {
EXPECT_FALSE(
sd_match(Add1, m_ReassociatableNUWAdd(m_Specific(Op0), m_Specific(Op1),
m_Specific(Op2))));
+ EXPECT_TRUE(
+ sd_match(Add7, m_ReassociatableNSWAdd(m_Specific(Op9), m_Specific(Op10),
+ m_Specific(Op11))));
+ EXPECT_TRUE(
+ sd_match(Add7, m_ReassociatableNUWAdd(m_Specific(Op9), m_Specific(Op10),
+ m_Specific(Op11))));
}
More information about the llvm-commits
mailing list