[llvm] [DAG] SDPatternMatch - add matchers for reassociatable binops (PR #119985)
Ethan Kaji via llvm-commits
llvm-commits at lists.llvm.org
Sat Dec 14 18:25:47 PST 2024
https://github.com/Esan5 updated https://github.com/llvm/llvm-project/pull/119985
>From cdd6b8fa53d9926ad9742ee15e886057c31deaf8 Mon Sep 17 00:00:00 2001
From: Ethan Kaji <ethan.kaji at gmail.com>
Date: Sat, 14 Dec 2024 15:16:32 -0600
Subject: [PATCH 1/2] add reassociatable matchers
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 91 +++++++++++++++++++
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 85 +++++++++++++++++
2 files changed, 176 insertions(+)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index d21cc962da46cb..2332cc89fad211 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1072,6 +1072,97 @@ inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
return m_Xor(V, m_AllOnes());
}
+/// Matches a tree built from a single opcode against an unordered set of
+/// operand patterns.
+///
+/// Starting at the candidate node, every operand chain whose nodes carry
+/// \c Opcode is flattened into a list of leaves; a node with any other opcode
+/// (including the root itself) counts as one leaf. The match succeeds when the
+/// number of leaves equals the number of patterns and each leaf can be paired
+/// with a distinct pattern that accepts it, regardless of nesting or order.
+template <typename... PatternTs> struct ReassociatableOpc_match {
+  unsigned Opcode;
+  std::tuple<PatternTs...> Patterns;
+
+  ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
+      : Opcode(Opcode), Patterns(Patterns...) {}
+
+  /// Returns true iff the leaves of the \c Opcode tree rooted at \p N can be
+  /// assigned one-to-one to \c Patterns.
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    constexpr size_t NumPatterns = sizeof...(PatternTs);
+
+    SmallVector<SDValue> Leaves;
+    collectLeaves(N, Leaves);
+    if (Leaves.size() != NumPatterns)
+      return false;
+
+    // Matches[LeafIdx] lists, in ascending order, the tuple positions of all
+    // patterns that accept Leaves[LeafIdx].
+    //
+    // Every pattern is evaluated against every leaf here, before any
+    // assignment is chosen.
+    // NOTE(review): a pattern with binding side effects may therefore be run
+    // on leaves it is not finally paired with -- confirm before relying on
+    // captured values.
+    SmallVector<SmallVector<size_t>> Matches(Leaves.size());
+    for (size_t LeafIdx = 0, E = Leaves.size(); LeafIdx != E; ++LeafIdx) {
+      size_t PatIdx = 0;
+      auto Record = [&](auto &P) {
+        if (sd_context_match(Leaves[LeafIdx], Ctx, P))
+          Matches[LeafIdx].emplace_back(PatIdx);
+        ++PatIdx;
+      };
+      // The comma fold visits the patterns left to right, so PatIdx always
+      // equals the tuple position of the pattern being tested.
+      std::apply([&](auto &...Ps) { (Record(Ps), ...); }, Patterns);
+    }
+
+    SmallVector<bool> Used(NumPatterns, false);
+    return reassociatableMatchHelper(Matches, Used);
+  }
+
+  /// Appends to \p Leaves, in operand order, every value reachable from \p V
+  /// through nodes whose opcode is \c Opcode but which itself has a different
+  /// opcode.
+  void collectLeaves(SDValue V, SmallVector<SDValue> &Leaves) {
+    if (V->getOpcode() != Opcode) {
+      Leaves.emplace_back(V);
+      return;
+    }
+    for (size_t OpIdx = 0; OpIdx < V->getNumOperands(); ++OpIdx)
+      collectLeaves(V->getOperand(OpIdx), Leaves);
+  }
+
+  /// Backtracking search for a pairing of leaves with distinct patterns.
+  ///
+  /// \p Matches holds, per leaf, the pattern indices that accept it; \p Used
+  /// flags the patterns already claimed by leaves before \p Curr. Returns true
+  /// once every leaf from \p Curr onwards has been given an unused pattern.
+  [[nodiscard]] inline bool
+  reassociatableMatchHelper(const SmallVector<SmallVector<size_t>> &Matches,
+                            SmallVector<bool> &Used, size_t Curr = 0) {
+    if (Curr == Matches.size())
+      return true;
+
+    for (size_t Candidate : Matches[Curr]) {
+      if (!Used[Candidate]) {
+        Used[Candidate] = true;
+        if (reassociatableMatchHelper(Matches, Used, Curr + 1))
+          return true;
+        // Release the pattern so a later candidate for this leaf can let
+        // another leaf claim it.
+        Used[Candidate] = false;
+      }
+    }
+    return false;
+  }
+};
+
+/// Creates a matcher for an ISD::ADD tree whose flattened leaves are accepted
+/// by \p Patterns, one leaf per pattern, in any order.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableAdd(const PatternTs &...Patterns) {
+  using MatcherT = ReassociatableOpc_match<PatternTs...>;
+  return MatcherT(ISD::ADD, Patterns...);
+}
+
+/// Creates a matcher for an ISD::OR tree whose flattened leaves are accepted
+/// by \p Patterns, one leaf per pattern, in any order.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableOr(const PatternTs &...Patterns) {
+  using MatcherT = ReassociatableOpc_match<PatternTs...>;
+  return MatcherT(ISD::OR, Patterns...);
+}
+
+/// Creates a matcher for an ISD::AND tree whose flattened leaves are accepted
+/// by \p Patterns, one leaf per pattern, in any order.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableAnd(const PatternTs &...Patterns) {
+  using MatcherT = ReassociatableOpc_match<PatternTs...>;
+  return MatcherT(ISD::AND, Patterns...);
+}
+
+/// Creates a matcher for an ISD::MUL tree whose flattened leaves are accepted
+/// by \p Patterns, one leaf per pattern, in any order.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableMul(const PatternTs &...Patterns) {
+  using MatcherT = ReassociatableOpc_match<PatternTs...>;
+  return MatcherT(ISD::MUL, Patterns...);
+}
+
} // namespace SDPatternMatch
} // namespace llvm
#endif
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 259bdad0ab2723..8640c97e48986c 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -576,3 +576,88 @@ TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
EXPECT_TRUE(sd_match(Add, DAG.get(),
m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
}
+
+// Checks that the m_Reassociatable{Add,Mul,And,Or} matchers accept both a
+// balanced and a chained tree of the corresponding opcode. Every pattern is
+// the wildcard m_Value(), so each expectation effectively verifies that the
+// number of flattened leaves equals the number of patterns supplied.
+TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
+ using namespace SDPatternMatch;
+
+ SDLoc DL;
+ auto Int32VT = EVT::getIntegerVT(Context, 32);
+ auto Float32VT = EVT::getFloatingPointVT(32);
+
+ // Four distinct leaf values, each read from a different register number.
+ // NOTE(review): Op2 is created with a floating-point VT although it is used
+ // as an operand of the i32 nodes below -- confirm this is intended.
+ SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+ SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+ SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
+ SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
+
+ // (Op0 + Op1) + (Op2 + Op3)
+ SDValue ADD01 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
+ SDValue ADD23 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, Op3);
+ SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23);
+
+ // Two leaves for each half, four leaves for the balanced root.
+ EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(
+ ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+ // Op0 + (Op1 + (Op2 + Op3))
+ // NOTE(review): the node below is built from Op2, not Op1 as the comment
+ // above states; the leaf count is the same either way.
+ SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, ADD23);
+ SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
+ EXPECT_TRUE(
+ sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(ADD0123, m_ReassociatableAdd(m_Value(), m_Value(),
+ m_Value(), m_Value())));
+
+ // (Op0 * Op1) * (Op2 * Op3)
+ SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1);
+ SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);
+ SDValue MUL = DAG->getNode(ISD::MUL, DL, Int32VT, MUL01, MUL23);
+
+ EXPECT_TRUE(sd_match(MUL01, m_ReassociatableMul(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(MUL23, m_ReassociatableMul(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(
+ MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
+
+ // Op0 * (Op1 * (Op2 * Op3))
+ // NOTE(review): built from Op2 rather than Op1, unlike the comment above.
+ SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, MUL23);
+ SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
+ EXPECT_TRUE(
+ sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(MUL0123, m_ReassociatableMul(m_Value(), m_Value(),
+ m_Value(), m_Value())));
+
+ // In the comments below, "&&" denotes the ISD::AND node being built.
+ // (Op0 && Op1) && (Op2 && Op3)
+ SDValue AND01 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
+ SDValue AND23 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, Op3);
+ SDValue AND = DAG->getNode(ISD::AND, DL, Int32VT, AND01, AND23);
+
+ EXPECT_TRUE(sd_match(AND01, m_ReassociatableAnd(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(AND23, m_ReassociatableAnd(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(
+ AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+ // Op0 && (Op1 && (Op2 && Op3))
+ // NOTE(review): built from Op2 rather than Op1, unlike the comment above.
+ SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, AND23);
+ SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
+ EXPECT_TRUE(
+ sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(AND0123, m_ReassociatableAnd(m_Value(), m_Value(),
+ m_Value(), m_Value())));
+
+ // In the comments below, "||" denotes the ISD::OR node being built.
+ // (Op0 || Op1) || (Op2 || Op3)
+ SDValue OR01 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
+ SDValue OR23 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, Op3);
+ SDValue OR = DAG->getNode(ISD::OR, DL, Int32VT, OR01, OR23);
+
+ EXPECT_TRUE(sd_match(OR01, m_ReassociatableOr(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(OR23, m_ReassociatableOr(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(
+ OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+
+ // Op0 || (Op1 || (Op2 || Op3))
+ // NOTE(review): built from Op2 rather than Op1, unlike the comment above.
+ SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, OR23);
+ SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
+ EXPECT_TRUE(
+ sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(
+ OR0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+}
>From 792f096f0b21522ee3b86e99a88906804019d31b Mon Sep 17 00:00:00 2001
From: Ethan Kaji <ethan.kaji at gmail.com>
Date: Sat, 14 Dec 2024 20:25:24 -0600
Subject: [PATCH 2/2] tests
---
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 8640c97e48986c..05e5b457702f09 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -582,11 +582,10 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);
- auto Float32VT = EVT::getFloatingPointVT(32);
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
- SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
+ SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
// (Op0 + Op1) + (Op2 + Op3)
@@ -600,7 +599,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));
// Op0 + (Op1 + (Op2 + Op3))
- SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, ADD23);
+ SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op1, ADD23);
SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
EXPECT_TRUE(
sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
@@ -618,7 +617,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
// Op0 * (Op1 * (Op2 * Op3))
- SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, MUL23);
+ SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
EXPECT_TRUE(
sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
@@ -636,7 +635,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));
// Op0 && (Op1 && (Op2 && Op3))
- SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, AND23);
+ SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op1, AND23);
SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
EXPECT_TRUE(
sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
@@ -654,7 +653,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
// Op0 || (Op1 || (Op2 || Op3))
- SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, OR23);
+ SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op1, OR23);
SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
EXPECT_TRUE(
sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));
More information about the llvm-commits
mailing list