[llvm] [DAG] SDPatternMatch - Fix m_Reassociatable mismatching (PR #170061)
Artur Bermond Torres via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 4 08:13:16 PST 2025
https://github.com/bermondd updated https://github.com/llvm/llvm-project/pull/170061
>From ab6d06a1b3c810cbaa1d16cd46267a56c03fb3ce Mon Sep 17 00:00:00 2001
From: Artur Bermond Torres <41002679+bermondd at users.noreply.github.com>
Date: Sun, 30 Nov 2025 19:36:53 -0300
Subject: [PATCH 1/2] [DAG] SDPatternMatch - Fix m_Reassociatable mismatching
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 41 ++++++++++---------
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 10 +++++
2 files changed, 32 insertions(+), 19 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index daafd3fc9d825..435f340c5b9c9 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1315,19 +1315,12 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
if (Leaves.size() != NumPatterns)
return false;
- // Matches[I][J] == true iff sd_context_match(Leaves[I], Ctx,
- // std::get<J>(Patterns)) == true
- std::array<SmallBitVector, NumPatterns> Matches;
- for (size_t I = 0; I != NumPatterns; I++) {
- std::apply(
- [&](auto &...P) {
- (Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...);
- },
- Patterns);
- }
-
SmallBitVector Used(NumPatterns);
- return reassociatableMatchHelper(Matches, Used);
+ return std::apply(
+ [&](auto &...P) -> bool {
+ return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
+ },
+ Patterns);
}
void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
@@ -1339,21 +1332,31 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
}
}
+ // Searches for a distinct matching leaf per sub-pattern, with backtracking.
+ template <typename MatchContext, typename PatternHd, typename... PatternTl>
[[nodiscard]] inline bool
- reassociatableMatchHelper(ArrayRef<SmallBitVector> Matches,
- SmallBitVector &Used, size_t Curr = 0) {
- if (Curr == Matches.size())
- return true;
- for (size_t Match = 0, N = Matches[Curr].size(); Match < N; Match++) {
- if (!Matches[Curr][Match] || Used[Match])
+ reassociatableMatchHelper(const MatchContext &Ctx,
+ SmallVector<SDValue> &Leaves, SmallBitVector &Used,
+ PatternHd &HeadPattern,
+ PatternTl &...TailPatterns) {
+ for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
+ if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
continue;
Used[Match] = true;
- if (reassociatableMatchHelper(Matches, Used, Curr + 1))
+ if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
return true;
Used[Match] = false;
}
return false;
}
+
+ template <typename MatchContext>
+ [[nodiscard]] inline bool
+ reassociatableMatchHelper(const MatchContext &Ctx,
+ SmallVector<SDValue> &Leaves,
+ SmallBitVector &Used) {
+ return true;
+ }
};
template <typename... PatternTs>
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index c32ceee73472d..f071b4133e7dc 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -842,6 +842,16 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
EXPECT_TRUE(sd_match(
MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
+ // (Op0 + Op1) + Op0 binds correctly, allowing commutation
+ SDValue ADD010 = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, Op0);
+ SDValue A, B;
+ EXPECT_TRUE(sd_match(
+ ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(A))));
+ EXPECT_TRUE(sd_match(
+ ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(B))));
+ EXPECT_FALSE(sd_match(
+ ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Deferred(A))));
+
// Op0 * (Op1 * (Op2 * Op3))
SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
>From d989ca18b25d0e264cc6b8e61f5f07ea3300bf86 Mon Sep 17 00:00:00 2001
From: Artur Bermond Torres <git-arturbtorres at proton.me>
Date: Thu, 4 Dec 2025 13:10:36 -0300
Subject: [PATCH 2/2] [DAG] SDPatternMatch - Changed SmallVector to ArrayRef
and added more tests
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 4 ++--
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 12 ++++++++++++
2 files changed, 14 insertions(+), 2 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 435f340c5b9c9..9431d542f3267 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1336,7 +1336,7 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
template <typename MatchContext, typename PatternHd, typename... PatternTl>
[[nodiscard]] inline bool
reassociatableMatchHelper(const MatchContext &Ctx,
- SmallVector<SDValue> &Leaves, SmallBitVector &Used,
+ ArrayRef<SDValue> Leaves, SmallBitVector &Used,
PatternHd &HeadPattern,
PatternTl &...TailPatterns) {
for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
@@ -1353,7 +1353,7 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
template <typename MatchContext>
[[nodiscard]] inline bool
reassociatableMatchHelper(const MatchContext &Ctx,
- SmallVector<SDValue> &Leaves,
+ ArrayRef<SDValue> Leaves,
SmallBitVector &Used) {
return true;
}
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index f071b4133e7dc..69e48758b6be5 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -847,8 +847,20 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
SDValue A, B;
EXPECT_TRUE(sd_match(
ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(A))));
+ EXPECT_TRUE(sd_match(Op0, m_Deferred(A)));
+ EXPECT_TRUE(sd_match(Op1, m_Deferred(B)));
+ EXPECT_FALSE(sd_match(Op0, m_Deferred(B)));
+ EXPECT_FALSE(sd_match(Op1, m_Deferred(A)));
+ A.setNode(nullptr);
+ B.setNode(nullptr);
EXPECT_TRUE(sd_match(
ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(B))));
+ EXPECT_TRUE(sd_match(Op0, m_Deferred(B)));
+ EXPECT_TRUE(sd_match(Op1, m_Deferred(A)));
+ EXPECT_FALSE(sd_match(Op0, m_Deferred(A)));
+ EXPECT_FALSE(sd_match(Op1, m_Deferred(B)));
+ A.setNode(nullptr);
+ B.setNode(nullptr);
EXPECT_FALSE(sd_match(
ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Deferred(A))));
More information about the llvm-commits
mailing list