[llvm] [DAG] SDPatternMatch - Fix m_Reassociatable mismatching (PR #170061)

Artur Bermond Torres via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 4 08:13:16 PST 2025


https://github.com/bermondd updated https://github.com/llvm/llvm-project/pull/170061

>From ab6d06a1b3c810cbaa1d16cd46267a56c03fb3ce Mon Sep 17 00:00:00 2001
From: Artur Bermond Torres <41002679+bermondd at users.noreply.github.com>
Date: Sun, 30 Nov 2025 19:36:53 -0300
Subject: [PATCH 1/2] [DAG] SDPatternMatch - Fix m_Reassociatable mismatching

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 41 ++++++++++---------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 10 +++++
 2 files changed, 32 insertions(+), 19 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index daafd3fc9d825..435f340c5b9c9 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1315,19 +1315,12 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
     if (Leaves.size() != NumPatterns)
       return false;
 
-    // Matches[I][J] == true iff sd_context_match(Leaves[I], Ctx,
-    // std::get<J>(Patterns)) == true
-    std::array<SmallBitVector, NumPatterns> Matches;
-    for (size_t I = 0; I != NumPatterns; I++) {
-      std::apply(
-          [&](auto &...P) {
-            (Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...);
-          },
-          Patterns);
-    }
-
     SmallBitVector Used(NumPatterns);
-    return reassociatableMatchHelper(Matches, Used);
+    return std::apply(
+        [&](auto &...P) -> bool {
+          return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
+        },
+        Patterns);
   }
 
   void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
@@ -1339,21 +1332,31 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
     }
   }
 
+  // Searches for a distinct matching leaf for every sub-pattern.
+  template <typename MatchContext, typename PatternHd, typename... PatternTl>
   [[nodiscard]] inline bool
-  reassociatableMatchHelper(ArrayRef<SmallBitVector> Matches,
-                            SmallBitVector &Used, size_t Curr = 0) {
-    if (Curr == Matches.size())
-      return true;
-    for (size_t Match = 0, N = Matches[Curr].size(); Match < N; Match++) {
-      if (!Matches[Curr][Match] || Used[Match])
+  reassociatableMatchHelper(const MatchContext &Ctx,
+                            SmallVector<SDValue> &Leaves, SmallBitVector &Used,
+                            PatternHd &HeadPattern,
+                            PatternTl &...TailPatterns) {
+    for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
+      if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
         continue;
       Used[Match] = true;
-      if (reassociatableMatchHelper(Matches, Used, Curr + 1))
+      if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
         return true;
       Used[Match] = false;
     }
     return false;
   }
+
+  template <typename MatchContext>
+  [[nodiscard]] inline bool
+  reassociatableMatchHelper(const MatchContext &Ctx,
+                            SmallVector<SDValue> &Leaves,
+                            SmallBitVector &Used) {
+    return true;
+  }
 };
 
 template <typename... PatternTs>
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index c32ceee73472d..f071b4133e7dc 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -842,6 +842,16 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
   EXPECT_TRUE(sd_match(
       MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
 
+  // (Op0 + Op1) + Op0 binds correctly, allowing commutation
+  SDValue ADD010 = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, Op0);
+  SDValue A, B;
+  EXPECT_TRUE(sd_match(
+      ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(A))));
+  EXPECT_TRUE(sd_match(
+      ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(B))));
+  EXPECT_FALSE(sd_match(
+      ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Deferred(A))));
+
   // Op0 * (Op1 * (Op2 * Op3))
   SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
   SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);

>From d989ca18b25d0e264cc6b8e61f5f07ea3300bf86 Mon Sep 17 00:00:00 2001
From: Artur Bermond Torres <git-arturbtorres at proton.me>
Date: Thu, 4 Dec 2025 13:10:36 -0300
Subject: [PATCH 2/2] [DAG] SDPatternMatch - Changed SmallVector to ArrayRef
 and added more tests

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h           |  4 ++--
 .../CodeGen/SelectionDAGPatternMatchTest.cpp         | 12 ++++++++++++
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 435f340c5b9c9..9431d542f3267 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1336,7 +1336,7 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
   template <typename MatchContext, typename PatternHd, typename... PatternTl>
   [[nodiscard]] inline bool
   reassociatableMatchHelper(const MatchContext &Ctx,
-                            SmallVector<SDValue> &Leaves, SmallBitVector &Used,
+                            ArrayRef<SDValue> Leaves, SmallBitVector &Used,
                             PatternHd &HeadPattern,
                             PatternTl &...TailPatterns) {
     for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
@@ -1353,7 +1353,7 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
   template <typename MatchContext>
   [[nodiscard]] inline bool
   reassociatableMatchHelper(const MatchContext &Ctx,
-                            SmallVector<SDValue> &Leaves,
+                            ArrayRef<SDValue> Leaves,
                             SmallBitVector &Used) {
     return true;
   }
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index f071b4133e7dc..69e48758b6be5 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -847,8 +847,20 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
   SDValue A, B;
   EXPECT_TRUE(sd_match(
       ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(A))));
+  EXPECT_TRUE(sd_match(Op0, m_Deferred(A)));
+  EXPECT_TRUE(sd_match(Op1, m_Deferred(B)));
+  EXPECT_FALSE(sd_match(Op0, m_Deferred(B)));
+  EXPECT_FALSE(sd_match(Op1, m_Deferred(A)));
+  A.setNode(nullptr);
+  B.setNode(nullptr);
   EXPECT_TRUE(sd_match(
       ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(B))));
+  EXPECT_TRUE(sd_match(Op0, m_Deferred(B)));
+  EXPECT_TRUE(sd_match(Op1, m_Deferred(A)));
+  EXPECT_FALSE(sd_match(Op0, m_Deferred(A)));
+  EXPECT_FALSE(sd_match(Op1, m_Deferred(B)));
+  A.setNode(nullptr);
+  B.setNode(nullptr);
   EXPECT_FALSE(sd_match(
       ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Deferred(A))));
 



More information about the llvm-commits mailing list