[llvm] [DAG] SDPatternMatch - add matchers for reassociatable binops (PR #119985)

Min-Yih Hsu via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 2 16:52:31 PST 2025


================
@@ -576,3 +576,87 @@ TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) {
   EXPECT_TRUE(sd_match(Add, DAG.get(),
                        m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value())))));
 }
+
+TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
+  using namespace SDPatternMatch;
+
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
+  SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
+
+  // Balanced tree, two leaf operands per side: (Op0 + Op1) + (Op2 + Op3)
+  SDValue ADD01 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
+  SDValue ADD23 = DAG->getNode(ISD::ADD, DL, Int32VT, Op2, Op3);
+  SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23);
+
+  EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      ADD, m_ReassociatableAdd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Right-nested chain reusing ADD23: Op0 + (Op1 + (Op2 + Op3))
+  SDValue ADD123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op1, ADD23);
+  SDValue ADD0123 = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, ADD123);
+  EXPECT_TRUE(
+      sd_match(ADD123, m_ReassociatableAdd(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(ADD0123, m_ReassociatableAdd(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // Balanced tree, two leaf operands per side: (Op0 * Op1) * (Op2 * Op3)
+  SDValue MUL01 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, Op1);
+  SDValue MUL23 = DAG->getNode(ISD::MUL, DL, Int32VT, Op2, Op3);
+  SDValue MUL = DAG->getNode(ISD::MUL, DL, Int32VT, MUL01, MUL23);
+
+  EXPECT_TRUE(sd_match(MUL01, m_ReassociatableMul(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(MUL23, m_ReassociatableMul(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      MUL, m_ReassociatableMul(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Right-nested chain reusing MUL23: Op0 * (Op1 * (Op2 * Op3))
+  SDValue MUL123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op1, MUL23);
+  SDValue MUL0123 = DAG->getNode(ISD::MUL, DL, Int32VT, Op0, MUL123);
+  EXPECT_TRUE(
+      sd_match(MUL123, m_ReassociatableMul(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(MUL0123, m_ReassociatableMul(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // Balanced tree of bitwise ANDs (ISD::AND): (Op0 & Op1) & (Op2 & Op3)
+  SDValue AND01 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
+  SDValue AND23 = DAG->getNode(ISD::AND, DL, Int32VT, Op2, Op3);
+  SDValue AND = DAG->getNode(ISD::AND, DL, Int32VT, AND01, AND23);
+
+  EXPECT_TRUE(sd_match(AND01, m_ReassociatableAnd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(AND23, m_ReassociatableAnd(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      AND, m_ReassociatableAnd(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Right-nested chain reusing AND23: Op0 & (Op1 & (Op2 & Op3))
+  SDValue AND123 = DAG->getNode(ISD::AND, DL, Int32VT, Op1, AND23);
+  SDValue AND0123 = DAG->getNode(ISD::AND, DL, Int32VT, Op0, AND123);
+  EXPECT_TRUE(
+      sd_match(AND123, m_ReassociatableAnd(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(AND0123, m_ReassociatableAnd(m_Value(), m_Value(),
+                                                    m_Value(), m_Value())));
+
+  // Balanced tree of bitwise ORs (ISD::OR): (Op0 | Op1) | (Op2 | Op3)
+  SDValue OR01 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
+  SDValue OR23 = DAG->getNode(ISD::OR, DL, Int32VT, Op2, Op3);
+  SDValue OR = DAG->getNode(ISD::OR, DL, Int32VT, OR01, OR23);
+
+  EXPECT_TRUE(sd_match(OR01, m_ReassociatableOr(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(OR23, m_ReassociatableOr(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      OR, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+
+  // Right-nested chain reusing OR23: Op0 | (Op1 | (Op2 | Op3))
+  SDValue OR123 = DAG->getNode(ISD::OR, DL, Int32VT, Op1, OR23);
+  SDValue OR0123 = DAG->getNode(ISD::OR, DL, Int32VT, Op0, OR123);
+  EXPECT_TRUE(
+      sd_match(OR123, m_ReassociatableOr(m_Value(), m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(
+      OR0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
+}
----------------
mshockwave wrote:

Could you add some negative tests showing that these matchers work as expected even when there are non-associative operations in the expression tree?

https://github.com/llvm/llvm-project/pull/119985


More information about the llvm-commits mailing list