[llvm] 6b00ae6 - [DAG] SDPatternMatch - add matchers for reassociatable binops (#119985)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 21 13:05:11 PDT 2025
Author: Ethan Kaji
Date: 2025-03-21T13:05:06-07:00
New Revision: 6b00ae6359b6442b827d0961357a09ec8dce72a4
URL: https://github.com/llvm/llvm-project/commit/6b00ae6359b6442b827d0961357a09ec8dce72a4
DIFF: https://github.com/llvm/llvm-project/commit/6b00ae6359b6442b827d0961357a09ec8dce72a4.diff
LOG: [DAG] SDPatternMatch - add matchers for reassociatable binops (#119985)
Fixes https://github.com/llvm/llvm-project/issues/118847
Implements matchers for reassociatable opcodes, as well as helpers for the
commonly used reassociatable binary operators (ADD, OR, AND and MUL).
---------
Co-authored-by: Min-Yih Hsu <min at myhsu.dev>
Added:
Modified:
llvm/include/llvm/CodeGen/SDPatternMatch.h
llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 4488a6152117c..4e1e4fdccfea5 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -14,7 +14,9 @@
#define LLVM_CODEGEN_SDPATTERNMATCH_H
#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"
@@ -1134,6 +1136,87 @@ inline BinaryOpc_match<ValTy, AllOnes_match, true> m_Not(const ValTy &V) {
return m_Xor(V, m_AllOnes());
}
+/// Matches a (possibly nested) tree of \c Opcode nodes whose leaves - the
+/// first operands reached that are not themselves \c Opcode nodes - can be
+/// matched one-to-one, in any order, against \c Patterns.
+template <typename... PatternTs> struct ReassociatableOpc_match {
+  unsigned Opcode;
+  std::tuple<PatternTs...> Patterns;
+  static constexpr size_t NumPatterns = sizeof...(PatternTs);
+
+  ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns)
+      : Opcode(Opcode), Patterns(Patterns...) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    // A fixed-size buffer suffices: collectLeaves() fails as soon as the tree
+    // has more leaves than there are patterns.
+    std::array<SDValue, NumPatterns> Leaves;
+    size_t NumLeaves = 0;
+    if (!collectLeaves(N, Leaves, NumLeaves) || NumLeaves != NumPatterns)
+      return false;
+
+    SmallBitVector Used(NumPatterns);
+    return std::apply(
+        [&](auto &...P) {
+          return this->reassociatableMatchHelper(Ctx, Leaves, Used, P...);
+        },
+        Patterns);
+  }
+
+  /// Appends the leaves of the \c Opcode tree rooted at \p V to \p Leaves,
+  /// advancing \p NumLeaves. Returns false (leaving \p Leaves partially
+  /// filled) as soon as a leaf beyond the NumPatterns-th one is found.
+  bool collectLeaves(SDValue V, std::array<SDValue, NumPatterns> &Leaves,
+                     size_t &NumLeaves) {
+    if (V->getOpcode() != Opcode) {
+      if (NumLeaves == NumPatterns)
+        return false;
+      Leaves[NumLeaves++] = V;
+      return true;
+    }
+    return llvm::all_of(V->op_values(), [&](SDValue Op) {
+      return collectLeaves(Op, Leaves, NumLeaves);
+    });
+  }
+
+  /// Assigns \p Head to some unused leaf it matches, then recurses on
+  /// \p Tail, backtracking on failure. Patterns are matched during the
+  /// search (not up front) so the last successful match of each pattern -
+  /// and therefore anything it binds, e.g. m_Value(X) - is on the leaf it
+  /// was finally assigned to.
+  template <typename MatchContext, typename HeadT, typename... TailTs>
+  bool reassociatableMatchHelper(const MatchContext &Ctx,
+                                 const std::array<SDValue, NumPatterns> &Leaves,
+                                 SmallBitVector &Used, HeadT &Head,
+                                 TailTs &...Tail) {
+    for (unsigned I = 0; I != NumPatterns; ++I) {
+      if (Used.test(I) || !sd_context_match(Leaves[I], Ctx, Head))
+        continue;
+      Used.set(I);
+      if (reassociatableMatchHelper(Ctx, Leaves, Used, Tail...))
+        return true;
+      Used.reset(I);
+    }
+    return false;
+  }
+
+  /// Recursion base case: every pattern has been assigned a distinct leaf.
+  template <typename MatchContext>
+  bool reassociatableMatchHelper(const MatchContext &,
+                                 const std::array<SDValue, NumPatterns> &,
+                                 SmallBitVector &) {
+    return true;
+  }
+};
+
+/// Matches a (possibly nested) tree of ISD::ADD nodes whose leaves match
+/// \p Patterns one-to-one in any order, e.g. (X + Y) + Z or Z + (Y + X).
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableAdd(const PatternTs &...Patterns) {
+  return {ISD::ADD, Patterns...};
+}
+
+/// Like m_ReassociatableAdd, but for a tree of ISD::OR nodes.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableOr(const PatternTs &...Patterns) {
+  return {ISD::OR, Patterns...};
+}
+
+/// Like m_ReassociatableAdd, but for a tree of ISD::AND nodes.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableAnd(const PatternTs &...Patterns) {
+  return {ISD::AND, Patterns...};
+}
+
+/// Like m_ReassociatableAdd, but for a tree of ISD::MUL nodes.
+template <typename... PatternTs>
+inline ReassociatableOpc_match<PatternTs...>
+m_ReassociatableMul(const PatternTs &...Patterns) {
+  return {ISD::MUL, Patterns...};
+}
+
} // namespace SDPatternMatch
} // namespace llvm
#endif
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 51fec9b68b558..df35a8678e0a4 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -651,3 +651,128 @@ TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
EXPECT_TRUE(sd_match(Add, DAG.get(),
m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
}
+
+TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
+  using namespace SDPatternMatch;
+
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
+  SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
+
+  // (Op0 + Op1) + (Op2 + Op3)
+  SDValue ADD01 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
+  SDValue ADD23 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, Op3);
+  SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23);
+
+  EXPECT_FALSE(sd_match(ADD01, m_ReassociatableAdd(m_Value())));
+  EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 + (Op1 + (Op2 + Op3))
+  SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op1, ADD23);
+  SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
+  EXPECT_TRUE(
+      sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(ADD0123, m_ReassociatableAdd(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // (Op0 - Op1) + (Op2 - Op3)
+  SDValue SUB01 = DAG->getNode(ISD::SUB, DL, Int32VT, Op0, Op1);
+  SDValue SUB23 = DAG->getNode(ISD::SUB, DL, Int32VT, Op2, Op3);
+  SDValue ADDS0123 = DAG->getNode(ISD::ADD, DL, Int32VT, SUB01, SUB23);
+
+  EXPECT_FALSE(sd_match(SUB01, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(ADDS0123, m_ReassociatableAdd(m_Value(), m_Value(),
+                                                      m_Value(), m_Value())));
+
+  // SUB + SUB matches (Op0 - Op1) + (Op2 - Op3), in any pattern order
+  EXPECT_TRUE(
+      sd_match(ADDS0123, m_ReassociatableAdd(m_Sub(m_Value(), m_Value()),
+                                             m_Sub(m_Value(), m_Value()))));
+  EXPECT_TRUE(sd_match(ADDS0123, m_ReassociatableAdd(m_Specific(SUB23),
+                                                     m_Specific(SUB01))));
+
+  // (Op0 * Op1) * (Op2 * Op3)
+  SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1);
+  SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);
+  SDValue MUL = DAG->getNode(ISD::MUL, DL, Int32VT, MUL01, MUL23);
+
+  EXPECT_TRUE(sd_match(MUL01, m_ReassociatableMul(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(MUL23, m_ReassociatableMul(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 * (Op1 * (Op2 * Op3))
+  SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
+  SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
+  EXPECT_TRUE(
+      sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(MUL0123, m_ReassociatableMul(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // (Op0 - Op1) * (Op2 - Op3)
+  SDValue MULS0123 = DAG->getNode(ISD::MUL, DL, Int32VT, SUB01, SUB23);
+  EXPECT_TRUE(
+      sd_match(MULS0123, m_ReassociatableMul(m_Sub(m_Value(), m_Value()),
+                                             m_Sub(m_Value(), m_Value()))));
+  EXPECT_FALSE(sd_match(MULS0123, m_ReassociatableMul(m_Value(), m_Value(),
+                                                      m_Value(), m_Value())));
+
+  // (Op0 & Op1) & (Op2 & Op3)
+  SDValue AND01 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
+  SDValue AND23 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, Op3);
+  SDValue AND = DAG->getNode(ISD::AND, DL, Int32VT, AND01, AND23);
+
+  EXPECT_TRUE(sd_match(AND01, m_ReassociatableAnd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(AND23, m_ReassociatableAnd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 & (Op1 & (Op2 & Op3))
+  SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op1, AND23);
+  SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
+  EXPECT_TRUE(
+      sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(AND0123, m_ReassociatableAnd(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // (Op0 - Op1) & (Op2 - Op3)
+  SDValue ANDS0123 = DAG->getNode(ISD::AND, DL, Int32VT, SUB01, SUB23);
+  EXPECT_TRUE(
+      sd_match(ANDS0123, m_ReassociatableAnd(m_Sub(m_Value(), m_Value()),
+                                             m_Sub(m_Value(), m_Value()))));
+  EXPECT_FALSE(sd_match(ANDS0123, m_ReassociatableAnd(m_Value(), m_Value(),
+                                                      m_Value(), m_Value())));
+
+  // (Op0 | Op1) | (Op2 | Op3)
+  SDValue OR01 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
+  SDValue OR23 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, Op3);
+  SDValue OR = DAG->getNode(ISD::OR, DL, Int32VT, OR01, OR23);
+
+  EXPECT_TRUE(sd_match(OR01, m_ReassociatableOr(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(OR23, m_ReassociatableOr(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Op0 | (Op1 | (Op2 | Op3))
+  SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op1, OR23);
+  SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
+  EXPECT_TRUE(
+      sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      OR0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // (Op0 - Op1) | (Op2 - Op3)
+  SDValue ORS0123 = DAG->getNode(ISD::OR, DL, Int32VT, SUB01, SUB23);
+  EXPECT_TRUE(
+      sd_match(ORS0123, m_ReassociatableOr(m_Sub(m_Value(), m_Value()),
+                                           m_Sub(m_Value(), m_Value()))));
+  EXPECT_FALSE(sd_match(
+      ORS0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+}
More information about the llvm-commits
mailing list