[llvm] [DAG] SDPatternMatch - Fix m_Reassociatable mismatching (PR #170061)

Artur Bermond Torres via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 30 16:39:36 PST 2025


https://github.com/bermondd created https://github.com/llvm/llvm-project/pull/170061

Fixes #169645

The issue was caused by a for-loop that improperly overwrote SDValue bindings when m_Reassociatable is given two or more patterns that (1) call m_Value with an SDValue parameter and (2) differ in that parameter. This fix also adds unit tests covering SDValue bindings inside m_Reassociatable patterns.

Essentially, the original implementation first tried to match every possible combination of leaf node and pattern, stored the results in a matrix-like structure, and then ran a recursive search on that matrix to check whether every leaf could be paired with a distinct pattern. The problem is that m_Value has a side effect: it writes the matched value into its SDValue argument, so each successful attempt while building the matrix overwrote the binding left by the previous one, corrupting these values.

Below is an example of this, following the order of execution in the original implementation and using the case raised in issue #169645, where this behavior was found. The example tries to match ((a >> 1) + (b >> 1) + (a & b & 1)), using uppercase letters for the SDValue variables themselves and lowercase for their values. The result is that A and B are both bound to the same value (b), which was the behavior observed in the issue:

| Line | Leaf | Pattern | Match? | Effect |
|--------|--------|--------|--------|--------|
| 1 | a >> 1 | m_Srl(m_Value(A), m_One()) | Yes | A <- a |
| 2 | a >> 1 | m_Srl(m_Value(B), m_One()) | Yes | B <- a |
| 3 | a >> 1 | m_ReassociatableAnd(m_Deferred(A), m_Deferred(B), m_One()) | No | -- |
| 4 | b >> 1 | m_Srl(m_Value(A), m_One()) | Yes | A <- b |
| 5 | b >> 1 | m_Srl(m_Value(B), m_One()) | Yes | B <- b |
| 6 | b >> 1 | m_ReassociatableAnd(m_Deferred(A), m_Deferred(B), m_One()) | No | -- |
| 7 | a & b & 1 | m_Srl(m_Value(A), m_One()) | No | -- |
| 8 | a & b & 1 | m_Srl(m_Value(B), m_One()) | No | -- |
| 9 | a & b & 1 | m_ReassociatableAnd(m_Deferred(A), m_Deferred(B), m_One()) | Only if a == b | -- |
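
The corruption can be reproduced outside LLVM with a small model. This is a simplified, hypothetical sketch: `Leaf`, `srlBind`, `andDeferred`, and `matchTwoPhase` are not LLVM APIs; they only mimic the binding behavior of `m_Srl(m_Value(A), m_One())` and the deferred check of `m_ReassociatableAnd(m_Deferred(A), m_Deferred(B), m_One())`, using `int` values in place of SDValues. `matchTwoPhase` follows the original order of evaluation: fill the whole leaf-by-pattern matrix first, then search it.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a DAG leaf (not an LLVM type):
// Kind::Srl models "X >> 1", Kind::And models "X & Y & 1".
struct Leaf {
  enum class Kind { Srl, And } K;
  int X;
  int Y; // only used by Kind::And
};

// A pattern reports whether it matches a leaf and may write to captured
// slots, mimicking the side effect of m_Value(A).
using Pattern = std::function<bool(const Leaf &)>;

// Models m_Srl(m_Value(Slot), m_One()): the opcode is checked first, and
// Slot is overwritten on every successful match.
Pattern srlBind(int &Slot) {
  return [&Slot](const Leaf &L) {
    if (L.K != Leaf::Kind::Srl)
      return false;
    Slot = L.X;
    return true;
  };
}

// Models m_ReassociatableAnd(m_Deferred(A), m_Deferred(B), m_One()):
// reads A and B at match time; operand order does not matter.
Pattern andDeferred(const int &A, const int &B) {
  return [&A, &B](const Leaf &L) {
    return L.K == Leaf::Kind::And &&
           ((L.X == A && L.Y == B) || (L.X == B && L.Y == A));
  };
}

// Phase 2: backtracking search that gives every leaf a distinct pattern,
// looking only at the precomputed matrix.
bool assign(const std::vector<std::vector<bool>> &M, std::vector<bool> &Used,
            std::size_t Curr) {
  if (Curr == M.size())
    return true;
  for (std::size_t J = 0; J != M[Curr].size(); ++J) {
    if (!M[Curr][J] || Used[J])
      continue;
    Used[J] = true;
    if (assign(M, Used, Curr + 1))
      return true;
    Used[J] = false;
  }
  return false;
}

// Two-phase matcher in the style of the original implementation.
bool matchTwoPhase(const std::vector<Leaf> &Leaves,
                   const std::vector<Pattern> &Patterns, int &Attempts) {
  const std::size_t N = Patterns.size();
  if (Leaves.size() != N)
    return false;
  // Phase 1: try every (leaf, pattern) pair up front, always N * N calls.
  // Each srlBind success clobbers its slot, so only the last write survives.
  std::vector<std::vector<bool>> M(N, std::vector<bool>(N));
  for (std::size_t I = 0; I != N; ++I)
    for (std::size_t J = 0; J != N; ++J) {
      ++Attempts;
      M[I][J] = Patterns[J](Leaves[I]);
    }
  std::vector<bool> Used(N);
  return assign(M, Used, 0);
}
```

With a = 1 and b = 2, `matchTwoPhase` returns false, leaves A == B == 2, and performs 9 pattern calls; with a == b the last matrix entry becomes true and the match succeeds, as in line 9 of the table.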

To fix this, the function now matches the patterns during the recursive search itself instead of building the matrix beforehand. This fixes the issue, but it means we now perform anywhere from n matching attempts (best case) to on the order of n! (worst case), instead of the fixed n^2 of the original, where n is the number of patterns provided. Going back to the table above, with this fix lines 2, 3, 4, 6, 7, and 8 never happen, so the only effects are A <- a and B <- b, and line 9 then matches correctly.
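
The fixed strategy can be sketched in the same kind of simplified model. Again, `Leaf`, `srlBind`, `andDeferred`, `matchInterleaved`, and `matchReassociated` are hypothetical stand-ins, not LLVM APIs, and `int` values replace SDValues. Each pattern is tried only while searching, pattern by pattern, and only against leaves that are still unclaimed, so a later deferred check sees the binding made by the pattern that actually claimed a leaf. The attempt counter illustrates the best case of n calls for this input.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical leaf model: Kind::Srl is "X >> 1", Kind::And is "X & Y & 1".
struct Leaf {
  enum class Kind { Srl, And } K;
  int X;
  int Y; // only used by Kind::And
};
using Pattern = std::function<bool(const Leaf &)>;

// Models m_Srl(m_Value(Slot), m_One()): binds Slot only on success.
Pattern srlBind(int &Slot) {
  return [&Slot](const Leaf &L) {
    if (L.K != Leaf::Kind::Srl)
      return false;
    Slot = L.X;
    return true;
  };
}

// Models m_ReassociatableAnd(m_Deferred(A), m_Deferred(B), m_One()).
Pattern andDeferred(const int &A, const int &B) {
  return [&A, &B](const Leaf &L) {
    return L.K == Leaf::Kind::And &&
           ((L.X == A && L.Y == B) || (L.X == B && L.Y == A));
  };
}

// Pattern-major backtracking: pattern P is called only after patterns
// 0..P-1 have each claimed a leaf, and only on leaves not yet claimed.
bool matchInterleaved(const std::vector<Leaf> &Leaves,
                      const std::vector<Pattern> &Patterns,
                      std::vector<bool> &Used, std::size_t P, int &Attempts) {
  if (P == Patterns.size())
    return true;
  for (std::size_t I = 0; I != Leaves.size(); ++I) {
    if (Used[I])
      continue; // skipped without calling the pattern
    ++Attempts;
    if (!Patterns[P](Leaves[I]))
      continue;
    Used[I] = true;
    if (matchInterleaved(Leaves, Patterns, Used, P + 1, Attempts))
      return true;
    Used[I] = false; // backtrack; retrying pattern P rebinds its slot
  }
  return false;
}

bool matchReassociated(const std::vector<Leaf> &Leaves,
                       const std::vector<Pattern> &Patterns, int &Attempts) {
  if (Leaves.size() != Patterns.size())
    return false;
  std::vector<bool> Used(Leaves.size());
  return matchInterleaved(Leaves, Patterns, Used, 0, Attempts);
}
```

For a = 1 and b = 2 this returns true with A == 1 and B == 2 after 3 pattern calls, compared with the unconditional n^2 = 9 calls of the original matrix construction.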

Maybe the way this was fixed could be improved, though. This is my first contribution to LLVM and I'm still fairly inexperienced with C++, so I'm not sure whether template recursion is discouraged, but it was the way I found to respect the compile-time constraints of the template. This was also my first ever PR to any open source project, so I don't know if this is too long; sorry if so. Open to any feedback :)
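
On the template-recursion point: `std::apply` expands the tuple of sub-patterns into a parameter pack, and because each sub-pattern has its own type there is no runtime index into that pack, so peeling off a head pattern and recursing on the tail (with a zero-pattern overload as the base case) is one way to walk it. A minimal standalone sketch of that shape, assuming hypothetical `ValueP`/`DeferredP` matchers as stand-ins for `m_Value`/`m_Deferred` and plain `int` leaves instead of SDValues (C++17 for `std::apply`):

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <vector>

// Hypothetical matcher types (not the LLVM ones). Each has a distinct type,
// so a tuple of them can only be walked at compile time.
struct ValueP { // models m_Value(V): binds V on every match
  int &V;
  bool match(int Leaf) {
    V = Leaf;
    return true;
  }
};
struct DeferredP { // models m_Deferred(V): compares against current V
  const int &V;
  bool match(int Leaf) const { return Leaf == V; }
};

// Base case: no patterns left, so every leaf has been claimed.
inline bool matchHelper(const std::vector<int> &, std::vector<bool> &) {
  return true;
}

// Recursive case: try the head pattern on each unused leaf, then recurse on
// the tail pack. The pack shrinks by one per level, so the recursion depth
// is fixed at compile time.
template <typename Head, typename... Tail>
bool matchHelper(const std::vector<int> &Leaves, std::vector<bool> &Used,
                 Head &H, Tail &...T) {
  for (std::size_t I = 0; I != Leaves.size(); ++I) {
    if (Used[I] || !H.match(Leaves[I]))
      continue;
    Used[I] = true;
    if (matchHelper(Leaves, Used, T...))
      return true;
    Used[I] = false;
  }
  return false;
}

// Unpacks the tuple into a pack and starts the search.
template <typename... Ps>
bool matchAll(const std::vector<int> &Leaves, std::tuple<Ps...> Patterns) {
  if (Leaves.size() != sizeof...(Ps))
    return false;
  std::vector<bool> Used(Leaves.size());
  return std::apply(
      [&](auto &...P) { return matchHelper(Leaves, Used, P...); }, Patterns);
}
```

Run on the leaves {10, 20, 10}, i.e. (Op0 + Op1) + Op0 with Op0 = 10 and Op1 = 20, this reproduces the three expectations of the new unit test: the first two pattern lists match (the second only after backtracking, ending with A == 20 and B == 10) and the third does not.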

From ab6d06a1b3c810cbaa1d16cd46267a56c03fb3ce Mon Sep 17 00:00:00 2001
From: Artur Bermond Torres <41002679+bermondd at users.noreply.github.com>
Date: Sun, 30 Nov 2025 19:36:53 -0300
Subject: [PATCH] [DAG] SDPatternMatch - Fix m_Reassociatable mismatching

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 41 ++++++++++---------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 10 +++++
 2 files changed, 32 insertions(+), 19 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index daafd3fc9d825..435f340c5b9c9 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1315,19 +1315,12 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
     if (Leaves.size() != NumPatterns)
       return false;
 
-    // Matches[I][J] == true iff sd_context_match(Leaves[I], Ctx,
-    // std::get<J>(Patterns)) == true
-    std::array<SmallBitVector, NumPatterns> Matches;
-    for (size_t I = 0; I != NumPatterns; I++) {
-      std::apply(
-          [&](auto &...P) {
-            (Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...);
-          },
-          Patterns);
-    }
-
     SmallBitVector Used(NumPatterns);
-    return reassociatableMatchHelper(Matches, Used);
+    return std::apply(
+        [&](auto &...P) -> bool {
+          return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
+        },
+        Patterns);
   }
 
   void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
@@ -1339,21 +1332,31 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
     }
   }
 
+  // Searches for a matching leaf for every sub-pattern.
+  template <typename MatchContext, typename PatternHd, typename... PatternTl>
   [[nodiscard]] inline bool
-  reassociatableMatchHelper(ArrayRef<SmallBitVector> Matches,
-                            SmallBitVector &Used, size_t Curr = 0) {
-    if (Curr == Matches.size())
-      return true;
-    for (size_t Match = 0, N = Matches[Curr].size(); Match < N; Match++) {
-      if (!Matches[Curr][Match] || Used[Match])
+  reassociatableMatchHelper(const MatchContext &Ctx,
+                            SmallVector<SDValue> &Leaves, SmallBitVector &Used,
+                            PatternHd &HeadPattern,
+                            PatternTl &...TailPatterns) {
+    for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
+      if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
         continue;
       Used[Match] = true;
-      if (reassociatableMatchHelper(Matches, Used, Curr + 1))
+      if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
         return true;
       Used[Match] = false;
     }
     return false;
   }
+
+  template <typename MatchContext>
+  [[nodiscard]] inline bool
+  reassociatableMatchHelper(const MatchContext &Ctx,
+                            SmallVector<SDValue> &Leaves,
+                            SmallBitVector &Used) {
+    return true;
+  }
 };
 
 template <typename... PatternTs>
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index c32ceee73472d..f071b4133e7dc 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -842,6 +842,16 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
   EXPECT_TRUE(sd_match(
       MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
 
+  // (Op0 + Op1) + Op0 binds correctly, allowing commutation
+  SDValue ADD010 = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, Op0);
+  SDValue A, B;
+  EXPECT_TRUE(sd_match(
+      ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(A))));
+  EXPECT_TRUE(sd_match(
+      ADD010, m_ReassociatableAdd(m_Value(A), m_Value(B), m_Deferred(B))));
+  EXPECT_FALSE(sd_match(
+      ADD010, m_ReassociatableAdd(m_Value(A), m_Deferred(A), m_Deferred(A))));
+
   // Op0 * (Op1 * (Op2 * Op3))
   SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
   SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);