[llvm] [DAG] SDPatternMatch - add matchers for reassociatable binops (PR #119985)

Ethan Kaji via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 14 18:27:29 PST 2024


https://github.com/Esan5 updated https://github.com/llvm/llvm-project/pull/119985

>From 8d928e4528bb44e0f0730db250d8b59af1221eb7 Mon Sep 17 00:00:00 2001
From: Ethan Kaji <ethan.kaji at gmail.com>
Date: Sat, 14 Dec 2024 15:16:32 -0600
Subject: [PATCH 1/2] add reassociatable matchers

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 91 +++++++++++++++++++
 .../CodeGen/SelectionDAGPatternMatchTest.cpp  | 85 +++++++++++++++++
 2 files changed, 176 insertions(+)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index d21cc962da46cb..2332cc89fad211 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1072,6 +1072,97 @@ inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
   return m_Xor(V, m_AllOnes());
 }
 
+/// Matches a (possibly nested) chain of a single associative and commutative
+/// binary opcode against a list of leaf patterns, regardless of how the chain
+/// is parenthesized or in which order the leaves appear. For example,
+/// m_ReassociatableAdd(P0, P1, P2) matches both (a + b) + c and a + (c + b)
+/// as long as the three leaves can be paired one-to-one with P0, P1 and P2.
+///
+/// The tree rooted at the matched value is flattened by expanding every node
+/// whose opcode equals Opcode into its operands; any other value is a leaf.
+/// The match succeeds iff there are exactly as many leaves as patterns and a
+/// one-to-one leaf/pattern assignment exists in which every leaf matches its
+/// pattern.
+///
+/// Every pattern is first tried against every leaf to discover candidate
+/// pairs. Once an assignment is found, each pattern is matched once more (in
+/// pattern order) against the leaf it was assigned, so that binding matchers
+/// such as m_Value(V) capture that leaf instead of the last leaf they happened
+/// to be tried on.
+///
+/// NOTE: inner nodes with the same opcode are expanded regardless of how many
+/// uses they have.
+template <typename... PatternTs> struct ReassociatableOpc_match {
+  static constexpr size_t NumPatterns = sizeof...(PatternTs);
+
+  unsigned Opcode;
+  std::tuple<PatternTs...> Patterns;
+
+  ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
+      : Opcode(Opcode), Patterns(Patterns...) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    SmallVector<SDValue> Leaves;
+    collectLeaves(N, Leaves);
+    if (Leaves.size() != NumPatterns)
+      return false;
+
+    // J is in Matches[I] iff sd_context_match(Leaves[I], Ctx,
+    // std::get<J>(Patterns)) == true. The comma fold below evaluates left to
+    // right, so J counts patterns in tuple order.
+    SmallVector<SmallVector<size_t>> Matches(Leaves.size());
+    for (size_t I = 0, E = Leaves.size(); I != E; ++I) {
+      size_t J = 0;
+      auto Record = [&](bool Matched) {
+        if (Matched)
+          Matches[I].push_back(J);
+        ++J;
+      };
+      std::apply(
+          [&](auto &...P) {
+            (Record(sd_context_match(Leaves[I], Ctx, P)), ...);
+          },
+          Patterns);
+    }
+
+    // Find a one-to-one leaf -> pattern assignment.
+    SmallVector<bool> Used(NumPatterns, false);
+    SmallVector<size_t> LeafToPattern(Leaves.size());
+    if (!reassociatableMatchHelper(Matches, Used, 0, &LeafToPattern))
+      return false;
+
+    // Replay the chosen assignment so any bindings reflect it. Leaves.size()
+    // == NumPatterns and each pattern is used once, so this is a bijection.
+    SmallVector<size_t> PatternToLeaf(NumPatterns);
+    for (size_t I = 0, E = Leaves.size(); I != E; ++I)
+      PatternToLeaf[LeafToPattern[I]] = I;
+    return rematchAssigned(Ctx, Leaves, PatternToLeaf,
+                           std::index_sequence_for<PatternTs...>());
+  }
+
+  /// Appends to Leaves, left to right, every value reachable from V through
+  /// nodes whose opcode is Opcode; those interior nodes are not appended.
+  void collectLeaves(SDValue V, SmallVectorImpl<SDValue> &Leaves) {
+    if (V->getOpcode() != Opcode) {
+      Leaves.push_back(V);
+      return;
+    }
+    for (SDValue Op : V->op_values())
+      collectLeaves(Op, Leaves);
+  }
+
+  /// Backtracking search for a one-to-one assignment of leaves to patterns.
+  /// Matches[I] lists the patterns leaf I matches; Used[J] marks pattern J as
+  /// already taken by one of the leaves before Curr. If LeafToPattern is
+  /// non-null it must have Matches.size() entries; on success entry I holds
+  /// the pattern chosen for leaf I.
+  [[nodiscard]] bool
+  reassociatableMatchHelper(const SmallVector<SmallVector<size_t>> &Matches,
+                            SmallVector<bool> &Used, size_t Curr = 0,
+                            SmallVectorImpl<size_t> *LeafToPattern = nullptr) {
+    if (Curr == Matches.size())
+      return true;
+    for (size_t Match : Matches[Curr]) {
+      if (Used[Match])
+        continue;
+      Used[Match] = true;
+      if (LeafToPattern)
+        (*LeafToPattern)[Curr] = Match;
+      if (reassociatableMatchHelper(Matches, Used, Curr + 1, LeafToPattern))
+        return true;
+      Used[Match] = false;
+    }
+    return false;
+  }
+
+  /// Re-runs pattern J against Leaves[PatternToLeaf[J]] for every J, in
+  /// pattern order, and returns true iff all of them match.
+  template <typename MatchContext, size_t... Js>
+  bool rematchAssigned(const MatchContext &Ctx,
+                       const SmallVectorImpl<SDValue> &Leaves,
+                       const SmallVectorImpl<size_t> &PatternToLeaf,
+                       std::index_sequence<Js...>) {
+    return (sd_context_match(Leaves[PatternToLeaf[Js]], Ctx,
+                             std::get<Js>(Patterns)) &&
+            ...);
+  }
+};
+
+/// Matches a tree of ISD::ADD nodes whose leaves can be paired one-to-one
+/// with \p Patterns, in any order and under any parenthesization.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableAdd(const PatternTs &...Patterns) {
+  return {ISD::ADD, Patterns...};
+}
+
+/// Matches a tree of ISD::OR nodes whose leaves can be paired one-to-one
+/// with \p Patterns, in any order and under any parenthesization.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableOr(const PatternTs &...Patterns) {
+  return {ISD::OR, Patterns...};
+}
+
+/// Matches a tree of ISD::AND nodes whose leaves can be paired one-to-one
+/// with \p Patterns, in any order and under any parenthesization.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableAnd(const PatternTs &...Patterns) {
+  return {ISD::AND, Patterns...};
+}
+
+/// Matches a tree of ISD::MUL nodes whose leaves can be paired one-to-one
+/// with \p Patterns, in any order and under any parenthesization.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableMul(const PatternTs &...Patterns) {
+  return {ISD::MUL, Patterns...};
+}
+
 } // namespace SDPatternMatch
 } // namespace llvm
 #endif
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 259bdad0ab2723..8640c97e48986c 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -576,3 +576,88 @@ TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
   EXPECT_TRUE(sd_match(Add, DAG.get(),
                        m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
 }
+
+TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
+  using namespace SDPatternMatch;
+
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+  auto Float32VT = EVT::getFloatingPointVT(32);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
+  SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
+
+  // (Op0 + Op1) + (Op2 + Op3)
+  SDValue ADD01 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
+  SDValue ADD23 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, Op3);
+  SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23);
+
+  EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 + (Op1 + (Op2 + Op3))
+  SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, ADD23);
+  SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
+  EXPECT_TRUE(
+      sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(ADD0123, m_ReassociatableAdd(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // (Op0 * Op1) * (Op2 * Op3)
+  SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1);
+  SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);
+  SDValue MUL = DAG->getNode(ISD::MUL, DL, Int32VT, MUL01, MUL23);
+
+  EXPECT_TRUE(sd_match(MUL01, m_ReassociatableMul(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(MUL23, m_ReassociatableMul(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 * (Op1 * (Op2 * Op3))
+  SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, MUL23);
+  SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
+  EXPECT_TRUE(
+      sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(MUL0123, m_ReassociatableMul(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // (Op0 && Op1) && (Op2 && Op3)
+  SDValue AND01 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
+  SDValue AND23 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, Op3);
+  SDValue AND = DAG->getNode(ISD::AND, DL, Int32VT, AND01, AND23);
+
+  EXPECT_TRUE(sd_match(AND01, m_ReassociatableAnd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(AND23, m_ReassociatableAnd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 && (Op1 && (Op2 && Op3))
+  SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, AND23);
+  SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
+  EXPECT_TRUE(
+      sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(AND0123, m_ReassociatableAnd(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // (Op0 || Op1) || (Op2 || Op3)
+  SDValue OR01 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
+  SDValue OR23 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, Op3);
+  SDValue OR = DAG->getNode(ISD::OR, DL, Int32VT, OR01, OR23);
+
+  EXPECT_TRUE(sd_match(OR01, m_ReassociatableOr(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(OR23, m_ReassociatableOr(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 || (Op1 || (Op2 || Op3))
+  SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, OR23);
+  SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
+  EXPECT_TRUE(
+      sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      OR0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // m_Value() accepts any leaf, so the checks above cannot tell whether the
+  // matcher really pairs leaves with patterns. Use m_Specific from here on.
+
+  // Leaves may be matched in any order.
+  EXPECT_TRUE(
+      sd_match(ADD, m_ReassociatableAdd(m_Specific(Op3), m_Specific(Op1),
+                                        m_Specific(Op0), m_Specific(Op2))));
+  EXPECT_TRUE(
+      sd_match(OR, m_ReassociatableOr(m_Specific(Op2), m_Specific(Op0),
+                                      m_Specific(Op3), m_Specific(Op1))));
+
+  // Op0 satisfies both patterns, so it must be moved off m_Value() for Op1
+  // to find a pattern; this exercises the backtracking search.
+  EXPECT_TRUE(
+      sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Specific(Op0))));
+
+  // A leaf cannot satisfy more than one pattern.
+  EXPECT_FALSE(
+      sd_match(ADD01, m_ReassociatableAdd(m_Specific(Op0), m_Specific(Op0))));
+
+  // The number of leaves must equal the number of patterns.
+  EXPECT_FALSE(
+      sd_match(ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value(),
+                                                   m_Value())));
+
+  // Only nodes with the matcher's own opcode are flattened.
+  EXPECT_FALSE(sd_match(
+      ADD, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
+  SDValue MulOfAdd = DAG->getNode(ISD::MUL, DL, Int32VT, ADD01, Op3);
+  EXPECT_TRUE(sd_match(
+      MulOfAdd, m_ReassociatableMul(m_Specific(Op3), m_Specific(ADD01))));
+  EXPECT_FALSE(sd_match(
+      MulOfAdd, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
+}

>From ed46b6e6128bd5e6ef0dd33921a39d1af46fb9c5 Mon Sep 17 00:00:00 2001
From: Ethan Kaji <ethan.kaji at gmail.com>
Date: Sat, 14 Dec 2024 20:25:24 -0600
Subject: [PATCH 2/2] tests

---
 .../CodeGen/SelectionDAGPatternMatchTest.cpp          | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 8640c97e48986c..05e5b457702f09 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -582,11 +582,10 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
 
   SDLoc DL;
   auto Int32VT = EVT::getIntegerVT(Context, 32);
-  auto Float32VT = EVT::getFloatingPointVT(32);
 
   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
-  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
   SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
 
   // (Op0 + Op1) + (Op2 + Op3)
@@ -600,7 +599,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
       ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));
 
   // Op0 + (Op1 + (Op2 + Op3))
-  SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, ADD23);
+  SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op1, ADD23);
   SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
   EXPECT_TRUE(
       sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
@@ -618,7 +617,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
       MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
 
   // Op0 * (Op1 * (Op2 * Op3))
-  SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, MUL23);
+  SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
   SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
   EXPECT_TRUE(
       sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
@@ -636,7 +635,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
       AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));
 
   // Op0 && (Op1 && (Op2 && Op3))
-  SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, AND23);
+  SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op1, AND23);
   SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
   EXPECT_TRUE(
       sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
@@ -654,7 +653,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
       OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
 
   // Op0 || (Op1 || (Op2 || Op3))
-  SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, OR23);
+  SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op1, OR23);
   SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
   EXPECT_TRUE(
       sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));



More information about the llvm-commits mailing list