[llvm] [DAG] SDPatternMatch - Replace runtime data structures with lengths known at compile time (PR #172064)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 14 02:12:23 PST 2025


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/172064

>From 74087d84b5d8ba627caa7bbef57f46206b11ba86 Mon Sep 17 00:00:00 2001
From: Artur Bermond Torres <git-arturbtorres at proton.me>
Date: Fri, 12 Dec 2025 15:04:15 -0300
Subject: [PATCH 1/2] [DAG] SDPatternMatch - Replaced some runtime data
 structures whose lengths are known at compile time

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 43 +++++++++++--------
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  |  2 +
 2 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index dda3b3827c7aa..510d8ee9bab79 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1310,49 +1310,58 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
   bool match(const MatchContext &Ctx, SDValue N) {
     constexpr size_t NumPatterns = std::tuple_size_v<std::tuple<PatternTs...>>;
 
-    SmallVector<SDValue> Leaves;
-    collectLeaves(N, Leaves);
-    if (Leaves.size() != NumPatterns)
+    std::array<SDValue, NumPatterns> Leaves;
+    size_t LeavesIdx = 0;
+    if (!(collectLeaves(N, Leaves, LeavesIdx) && (LeavesIdx == NumPatterns)))
       return false;
 
-    SmallBitVector Used(NumPatterns);
+    Bitset<NumPatterns> Used;
     return std::apply(
         [&](auto &...P) -> bool {
-          return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
+          return reassociatableMatchHelper<NumPatterns>(Ctx, Leaves, Used,
+                                                        P...);
         },
         Patterns);
   }
 
-  void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
+  template <std::size_t NumPatterns>
+  bool collectLeaves(SDValue V, std::array<SDValue, NumPatterns> &Leaves,
+                     std::size_t &LeafIdx) {
     if (V->getOpcode() == Opcode) {
       for (size_t I = 0, N = V->getNumOperands(); I < N; I++)
-        collectLeaves(V->getOperand(I), Leaves);
+        if ((LeafIdx == NumPatterns) ||
+            !collectLeaves(V->getOperand(I), Leaves, LeafIdx))
+          return false;
     } else {
-      Leaves.emplace_back(V);
+      Leaves[LeafIdx] = V;
+      LeafIdx++;
     }
+    return true;
   }
 
   // Searchs for a matching leaf for every sub-pattern.
-  template <typename MatchContext, typename PatternHd, typename... PatternTl>
+  template <std::size_t NumPatterns, typename MatchContext, typename PatternHd,
+            typename... PatternTl>
   [[nodiscard]] inline bool
   reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
-                            SmallBitVector &Used, PatternHd &HeadPattern,
+                            Bitset<NumPatterns> &Used, PatternHd &HeadPattern,
                             PatternTl &...TailPatterns) {
     for (size_t Match = 0, N = Used.size(); Match < N; Match++) {
       if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
         continue;
-      Used[Match] = true;
-      if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
+      Used.set(Match);
+      if (reassociatableMatchHelper<NumPatterns>(Ctx, Leaves, Used,
+                                                 TailPatterns...))
         return true;
-      Used[Match] = false;
+      Used.reset(Match);
     }
     return false;
   }
 
-  template <typename MatchContext>
-  [[nodiscard]] inline bool reassociatableMatchHelper(const MatchContext &Ctx,
-                                                      ArrayRef<SDValue> Leaves,
-                                                      SmallBitVector &Used) {
+  template <std::size_t NumPatterns, typename MatchContext>
+  [[nodiscard]] inline bool
+  reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
+                            Bitset<NumPatterns> &Used) {
     return true;
   }
 };
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 4fcd3fcb8c5c7..1afc034dd7b9e 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -803,6 +803,8 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
   SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23);
 
   EXPECT_FALSE(sd_match(ADD01, m_ReassociatableAdd(m_Value())));
+  EXPECT_FALSE(
+      sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(

>From 94a1f4265d945333b25bf458dc09e7c0cf0dc42b Mon Sep 17 00:00:00 2001
From: Artur Bermond Torres <git-arturbtorres at proton.me>
Date: Fri, 12 Dec 2025 16:35:47 -0300
Subject: [PATCH 2/2] Created static class member NumPatterns and removed the
 now-unnecessary template arguments

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 510d8ee9bab79..026ee035fcf54 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1302,14 +1302,14 @@ inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
 template <typename... PatternTs> struct ReassociatableOpc_match {
   unsigned Opcode;
   std::tuple<PatternTs...> Patterns;
+  constexpr static size_t NumPatterns =
+      std::tuple_size_v<std::tuple<PatternTs...>>;
 
   ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
       : Opcode(Opcode), Patterns(Patterns...) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
-    constexpr size_t NumPatterns = std::tuple_size_v<std::tuple<PatternTs...>>;
-
     std::array<SDValue, NumPatterns> Leaves;
     size_t LeavesIdx = 0;
     if (!(collectLeaves(N, Leaves, LeavesIdx) && (LeavesIdx == NumPatterns)))
@@ -1318,13 +1318,11 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
     Bitset<NumPatterns> Used;
     return std::apply(
         [&](auto &...P) -> bool {
-          return reassociatableMatchHelper<NumPatterns>(Ctx, Leaves, Used,
-                                                        P...);
+          return reassociatableMatchHelper(Ctx, Leaves, Used, P...);
         },
         Patterns);
   }
 
-  template <std::size_t NumPatterns>
   bool collectLeaves(SDValue V, std::array<SDValue, NumPatterns> &Leaves,
                      std::size_t &LeafIdx) {
     if (V->getOpcode() == Opcode) {
@@ -1340,8 +1338,7 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
   }
 
   // Searchs for a matching leaf for every sub-pattern.
-  template <std::size_t NumPatterns, typename MatchContext, typename PatternHd,
-            typename... PatternTl>
+  template <typename MatchContext, typename PatternHd, typename... PatternTl>
   [[nodiscard]] inline bool
   reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
                             Bitset<NumPatterns> &Used, PatternHd &HeadPattern,
@@ -1350,15 +1347,14 @@ template <typename... PatternTs> struct ReassociatableOpc_match {
       if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern)))
         continue;
       Used.set(Match);
-      if (reassociatableMatchHelper<NumPatterns>(Ctx, Leaves, Used,
-                                                 TailPatterns...))
+      if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...))
         return true;
       Used.reset(Match);
     }
     return false;
   }
 
-  template <std::size_t NumPatterns, typename MatchContext>
+  template <typename MatchContext>
   [[nodiscard]] inline bool
   reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef<SDValue> Leaves,
                             Bitset<NumPatterns> &Used) {



More information about the llvm-commits mailing list