[llvm] 6b00ae6 - [DAG] SDPatternMatch - add matchers for reassociatable binops (#119985)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 21 13:05:11 PDT 2025


Author: Ethan Kaji
Date: 2025-03-21T13:05:06-07:00
New Revision: 6b00ae6359b6442b827d0961357a09ec8dce72a4

URL: https://github.com/llvm/llvm-project/commit/6b00ae6359b6442b827d0961357a09ec8dce72a4
DIFF: https://github.com/llvm/llvm-project/commit/6b00ae6359b6442b827d0961357a09ec8dce72a4.diff

LOG: [DAG] SDPatternMatch - add matchers for reassociatable binops (#119985)

Fixes https://github.com/llvm/llvm-project/issues/118847

Implements matchers for reassociatable opcodes, as well as helper functions
for the commonly used reassociatable binary operations (add, or, and, mul).

---------

Co-authored-by: Min-Yih Hsu <min at myhsu.dev>

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SDPatternMatch.h
    llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 4488a6152117c..4e1e4fdccfea5 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -14,7 +14,9 @@
 #define LLVM_CODEGEN_SDPATTERNMATCH_H
 
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetLowering.h"
@@ -1134,6 +1136,87 @@ inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
   return m_Xor(V, m_AllOnes());
 }
 
+/// Matcher for a tree of \c Opcode nodes that ignores how the operation was
+/// associated and in which order its operands appear.
+///
+/// The tree rooted at the matched value is flattened into its leaves (the
+/// values that are not themselves \c Opcode nodes). The match succeeds when
+/// the number of leaves equals the number of patterns and every leaf can be
+/// paired with a distinct pattern that accepts it.
+template <typename... PatternTs> struct ReassociatableOpc_match {
+  unsigned Opcode;
+  std::tuple<PatternTs...> Patterns;
+
+  ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
+      : Opcode(Opcode), Patterns(Patterns...) {}
+
+  /// Returns true iff the leaves of \p N can be assigned one-to-one to
+  /// \c Patterns.
+  ///
+  /// Every pattern is evaluated against every leaf before an assignment is
+  /// searched for.
+  /// NOTE(review): a sub-pattern with side effects would therefore be invoked
+  /// once per leaf -- confirm this is acceptable for callers.
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    SmallVector<SDValue> Leaves;
+    collectLeaves(N, Leaves);
+    if (Leaves.size() != sizeof...(PatternTs))
+      return false;
+
+    // Matches[I][J] is set iff Leaves[I] satisfies std::get<J>(Patterns).
+    std::array<SmallBitVector, sizeof...(PatternTs)> Matches;
+    for (size_t I = 0, E = Leaves.size(); I != E; ++I) {
+      std::apply(
+          [&](auto &...P) {
+            (Matches[I].push_back(sd_context_match(Leaves[I], Ctx, P)), ...);
+          },
+          Patterns);
+    }
+
+    // No pattern has been handed to a leaf yet.
+    SmallBitVector Used(sizeof...(PatternTs));
+    return reassociatableMatchHelper(Matches, Used);
+  }
+
+  /// Flattens the \c Opcode tree rooted at \p V, appending its leaves to
+  /// \p Leaves in operand order. A \p V that is not an \c Opcode node is
+  /// itself the only leaf.
+  void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
+    if (V->getOpcode() != Opcode) {
+      Leaves.emplace_back(V);
+      return;
+    }
+    for (const SDValue &Op : V->op_values())
+      collectLeaves(Op, Leaves);
+  }
+
+  /// Recursive search that tries to give every leaf from index \p Curr onwards
+  /// a pattern that accepted it and is not yet taken.
+  ///
+  /// \param Matches Matches[I][J] is set iff leaf I satisfied pattern J.
+  /// \param Used    Bit J is set while pattern J is assigned to an earlier
+  ///                leaf; restored to its incoming state on failure.
+  /// \param Curr    Index of the first leaf that still needs a pattern.
+  /// \returns true iff all leaves from \p Curr onwards could be assigned.
+  [[nodiscard]] inline bool
+  reassociatableMatchHelper(const ArrayRef<SmallBitVector> Matches,
+                            SmallBitVector &Used, size_t Curr = 0) {
+    if (Curr == Matches.size())
+      return true;
+    for (size_t Match = 0, E = Matches[Curr].size(); Match < E; Match++) {
+      if (!Matches[Curr][Match] || Used[Match])
+        continue;
+      // Tentatively claim this pattern; release it again if the remaining
+      // leaves cannot be satisfied with it taken.
+      Used[Match] = true;
+      if (reassociatableMatchHelper(Matches, Used, Curr + 1))
+        return true;
+      Used[Match] = false;
+    }
+    return false;
+  }
+};
+
+/// Builds a ReassociatableOpc_match for ISD::ADD: the leaves of an ADD tree
+/// are matched against \p Patterns, one leaf per pattern, in any order.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableAdd(const PatternTs &...Patterns) {
+  using MatcherT = ReassociatableOpc_match<PatternTs...>;
+  return MatcherT(ISD::ADD, Patterns...);
+}
+
+/// Builds a ReassociatableOpc_match for ISD::OR: the leaves of an OR tree
+/// are matched against \p Patterns, one leaf per pattern, in any order.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableOr(const PatternTs &...Patterns) {
+  using MatcherT = ReassociatableOpc_match<PatternTs...>;
+  return MatcherT(ISD::OR, Patterns...);
+}
+
+/// Builds a ReassociatableOpc_match for ISD::AND: the leaves of an AND tree
+/// are matched against \p Patterns, one leaf per pattern, in any order.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableAnd(const PatternTs &...Patterns) {
+  using MatcherT = ReassociatableOpc_match<PatternTs...>;
+  return MatcherT(ISD::AND, Patterns...);
+}
+
+/// Builds a ReassociatableOpc_match for ISD::MUL: the leaves of a MUL tree
+/// are matched against \p Patterns, one leaf per pattern, in any order.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableMul(const PatternTs &...Patterns) {
+  using MatcherT = ReassociatableOpc_match<PatternTs...>;
+  return MatcherT(ISD::MUL, Patterns...);
+}
+
 } // namespace SDPatternMatch
 } // namespace llvm
 #endif

diff  --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 51fec9b68b558..df35a8678e0a4 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -651,3 +651,128 @@ TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
   EXPECT_TRUE(sd_match(Add, DAG.get(),
                        m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
 }
+
+// Exercises the m_Reassociatable{Add,Mul,And,Or} matchers on balanced and
+// right-leaning operation trees, and on trees whose leaves are SUB nodes.
+TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
+  using namespace SDPatternMatch;
+
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
+  SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
+
+  // (Op0 + Op1) + (Op2 + Op3)
+  SDValue ADD01 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
+  SDValue ADD23 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, Op3);
+  SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23);
+
+  EXPECT_FALSE(sd_match(ADD01, m_ReassociatableAdd(m_Value())));
+  EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Patterns may be listed in any order relative to the leaves...
+  EXPECT_TRUE(sd_match(
+      ADD, m_ReassociatableAdd(m_Specific(Op3), m_Specific(Op0),
+                               m_Specific(Op2), m_Specific(Op1))));
+  // ...but every leaf needs a pattern of its own: two patterns that only
+  // accept Op0 cannot both be used.
+  EXPECT_FALSE(sd_match(
+      ADD, m_ReassociatableAdd(m_Specific(Op0), m_Specific(Op0), m_Value(),
+                               m_Value())));
+
+  // Op0 + (Op1 + (Op2 + Op3))
+  SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op1, ADD23);
+  SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
+  EXPECT_TRUE(
+      sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(ADD0123, m_ReassociatableAdd(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // (Op0 - Op1) + (Op2 - Op3)
+  SDValue SUB01 = DAG->getNode(ISD::SUB, DL, Int32VT, Op0, Op1);
+  SDValue SUB23 = DAG->getNode(ISD::SUB, DL, Int32VT, Op2, Op3);
+  SDValue ADDS0123 = DAG->getNode(ISD::ADD, DL, Int32VT, SUB01, SUB23);
+
+  EXPECT_FALSE(sd_match(SUB01, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(ADDS0123, m_ReassociatableAdd(m_Value(), m_Value(),
+                                                      m_Value(), m_Value())));
+
+  // SUB + SUB matches (Op0 - Op1) + (Op2 - Op3)
+  EXPECT_TRUE(
+      sd_match(ADDS0123, m_ReassociatableAdd(m_Sub(m_Value(), m_Value()),
+                                             m_Sub(m_Value(), m_Value()))));
+  // The SUB nodes are leaves, so there are two leaves, not three.
+  EXPECT_FALSE(
+      sd_match(ADDS0123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
+
+  // (Op0 * Op1) * (Op2 * Op3)
+  SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1);
+  SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);
+  SDValue MUL = DAG->getNode(ISD::MUL, DL, Int32VT, MUL01, MUL23);
+
+  EXPECT_TRUE(sd_match(MUL01, m_ReassociatableMul(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(MUL23, m_ReassociatableMul(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 * (Op1 * (Op2 * Op3))
+  SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
+  SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
+  EXPECT_TRUE(
+      sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(MUL0123, m_ReassociatableMul(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // (Op0 - Op1) * (Op2 - Op3)
+  SDValue MULS0123 = DAG->getNode(ISD::MUL, DL, Int32VT, SUB01, SUB23);
+  EXPECT_TRUE(
+      sd_match(MULS0123, m_ReassociatableMul(m_Sub(m_Value(), m_Value()),
+                                             m_Sub(m_Value(), m_Value()))));
+  EXPECT_FALSE(sd_match(MULS0123, m_ReassociatableMul(m_Value(), m_Value(),
+                                                      m_Value(), m_Value())));
+
+  // (Op0 & Op1) & (Op2 & Op3)
+  SDValue AND01 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
+  SDValue AND23 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, Op3);
+  SDValue AND = DAG->getNode(ISD::AND, DL, Int32VT, AND01, AND23);
+
+  EXPECT_TRUE(sd_match(AND01, m_ReassociatableAnd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(AND23, m_ReassociatableAnd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 & (Op1 & (Op2 & Op3))
+  SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op1, AND23);
+  SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
+  EXPECT_TRUE(
+      sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(AND0123, m_ReassociatableAnd(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // (Op0 - Op1) & (Op2 - Op3)
+  SDValue ANDS0123 = DAG->getNode(ISD::AND, DL, Int32VT, SUB01, SUB23);
+  EXPECT_TRUE(
+      sd_match(ANDS0123, m_ReassociatableAnd(m_Sub(m_Value(), m_Value()),
+                                             m_Sub(m_Value(), m_Value()))));
+  EXPECT_FALSE(sd_match(ANDS0123, m_ReassociatableAnd(m_Value(), m_Value(),
+                                                      m_Value(), m_Value())));
+
+  // (Op0 | Op1) | (Op2 | Op3)
+  SDValue OR01 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
+  SDValue OR23 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, Op3);
+  SDValue OR = DAG->getNode(ISD::OR, DL, Int32VT, OR01, OR23);
+
+  EXPECT_TRUE(sd_match(OR01, m_ReassociatableOr(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(OR23, m_ReassociatableOr(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 | (Op1 | (Op2 | Op3))
+  SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op1, OR23);
+  SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
+  EXPECT_TRUE(
+      sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      OR0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // (Op0 - Op1) | (Op2 - Op3)
+  SDValue ORS0123 = DAG->getNode(ISD::OR, DL, Int32VT, SUB01, SUB23);
+  EXPECT_TRUE(
+      sd_match(ORS0123, m_ReassociatableOr(m_Sub(m_Value(), m_Value()),
+                                           m_Sub(m_Value(), m_Value()))));
+  EXPECT_FALSE(sd_match(
+      ORS0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+}


        


More information about the llvm-commits mailing list