[llvm] [DAG] Add SDPatternMatch::m_SetCC and update some combines to use it (PR #98646)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Sun Jul 14 04:24:56 PDT 2024
https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/98646
>From 9d39eafdef86ef8ae5f74a9a86a00c4e2b46d4da Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 12 Jul 2024 16:02:37 +0100
Subject: [PATCH 1/3] [DAG] Add SDPatternMatch::m_SetCC and update some
 combines to use it
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 43 ++++++++++++++++++
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 45 ++++++-------------
2 files changed, 57 insertions(+), 31 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index f39fbd95b3beb..07204d1f48c24 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -447,6 +447,49 @@ template <> struct EffectiveOperands<false> {
explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
};
+// === Ternary operations ===
+/// Matches a node whose opcode is \p Opcode and which has exactly three
+/// effective operands, checking them against Op0, Op1 and Op2 in order.
+///
+/// If \p Commutable is set, Op0/Op1 are additionally tried against the first
+/// two operands in swapped order. The third operand is always matched against
+/// Op2 in place and is never swapped.
+///
+/// \p ExcludeChain selects the EffectiveOperands<ExcludeChain> operand window.
+/// NOTE(review): presumably the `true` specialization skips a leading chain
+/// operand; only the `false` specialization (all operands) is visible here --
+/// confirm before relying on it.
+template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
+ bool ExcludeChain = false>
+struct TernaryOpc_match {
+ unsigned Opcode;
+ T0_P Op0;
+ T1_P Op1;
+ T2_P Op2;
+
+ TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
+ const T2_P &Op2)
+ : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}
+
+ /// Returns true if \p N has opcode \p Opcode (checked through \p Ctx) and
+ /// its operands satisfy (Op0, Op1, Op2), or (swapped Op0/Op1, Op2) when
+ /// Commutable.
+ template <typename MatchContext>
+ bool match(const MatchContext &Ctx, SDValue N) {
+ if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
+ EffectiveOperands<ExcludeChain> EO(N);
+ // Only valid for opcodes that always have three effective operands.
+ assert(EO.Size == 3);
+ // Try the operand pair in order, then swapped if Commutable; Op2 is
+ // checked last against the third operand in both cases. A capturing
+ // sub-pattern may still hold a value from a failed attempt if the
+ // overall match fails.
+ return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+ Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
+ (Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+ Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) &&
+ Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2));
+ }
+
+ return false;
+ }
+};
+
+/// Matches ISD::SETCC with operands (LHS, RHS, CondCode) checked against
+/// (Op0, Op1, Op2) in that exact order; a compare whose first two operands
+/// are swapped does not match. See m_c_SetCC for the order-insensitive form.
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, false, false>
+m_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+ return TernaryOpc_match<T0_P, T1_P, T2_P, false, false>(ISD::SETCC, Op0, Op1,
+ Op2);
+}
+
+/// Like m_SetCC, but also accepts the SETCC with its first two operands
+/// swapped. The condition-code pattern (Op2) is matched unchanged in either
+/// order, so a commuted match only implies the same predicate holds for
+/// (Op0, Op1) when the condition code is symmetric (e.g. SETEQ/SETNE).
+/// Callers matching ordered predicates must handle the swap themselves.
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
+m_c_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+ return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, Op0, Op1,
+ Op2);
+}
+
// === Binary operations ===
template <typename LHS_P, typename RHS_P, bool Commutable = false,
bool ExcludeChain = false>
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index cece76f658307..2f1bcc9bed88b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2300,24 +2300,12 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
return true;
}
- if (N.getOpcode() != ISD::SETCC ||
- N.getValueType().getScalarType() != MVT::i1 ||
- cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
- return false;
-
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
- assert(Op0.getValueType() == Op1.getValueType());
-
- if (isNullOrNullSplat(Op0))
- Op = Op1;
- else if (isNullOrNullSplat(Op1))
- Op = Op0;
- else
+ if (N.getValueType().getScalarType() != MVT::i1 ||
+ !sd_match(
+ N, m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
return false;
Known = DAG.computeKnownBits(Op);
-
return (Known.Zero | 1).isAllOnes();
}
@@ -2544,16 +2532,12 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
return SDValue();
// Match the zext operand as a setcc of a boolean.
- if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
- Z.getOperand(0).getValueType() != MVT::i1)
+ if (Z.getOperand(0).getValueType() != MVT::i1)
return SDValue();
// Match the compare as: setcc (X & 1), 0, eq.
- SDValue SetCC = Z.getOperand(0);
- ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
- if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
- SetCC.getOperand(0).getOpcode() != ISD::AND ||
- !isOneConstant(SetCC.getOperand(0).getOperand(1)))
+ if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
+ m_SpecificCondCode(ISD::SETEQ))))
return SDValue();
// We are adding/subtracting a constant and an inverted low bit. Turn that
@@ -2561,9 +2545,9 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
// add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
// sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
EVT VT = C.getValueType();
- SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
- SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
- DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
+ SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
+ SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
+ : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
}
@@ -11554,13 +11538,12 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
SDValue N1 = N->getOperand(1);
SDValue N2 = N->getOperand(2);
EVT VT = N->getValueType(0);
- if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
- return SDValue();
- SDValue Cond0 = N0.getOperand(0);
- SDValue Cond1 = N0.getOperand(1);
- ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
- if (VT != Cond0.getValueType())
+ SDValue Cond0, Cond1;
+ ISD::CondCode CC;
+ if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
+ m_CondCode(CC)))) ||
+ VT != Cond0.getValueType())
return SDValue();
// Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
>From 184bcfecf3654c8e0a7c0e10afe8b9a336c6024f Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 12 Jul 2024 18:42:04 +0100
Subject: [PATCH 2/3] Add SDPatternMatch::m_SetCC unit test coverage
---
.../CodeGen/SelectionDAGPatternMatchTest.cpp | 35 +++++++++++++++++++
1 file changed, 35 insertions(+)
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 46c385a0bc050..c0ee589447ff8 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -119,6 +119,41 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
}
+// Covers the ternary SETCC matchers: condition-code matching (specific and
+// capturing forms), the operand-order sensitivity of m_SetCC, and the
+// swapped-operand acceptance of m_c_SetCC.
+TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
+ SDLoc DL;
+ auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+ // Two distinct i32 values (CopyFromReg of registers 1 and 2) so that the
+ // operand order of each compare is observable.
+ SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+ SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+
+ // ICMP_EQ01 and ICMP_EQ10 compare the same two values in opposite order.
+ SDValue ICMP_UGT = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETUGT);
+ SDValue ICMP_EQ01 = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETEQ);
+ SDValue ICMP_EQ10 = DAG->getSetCC(DL, MVT::i1, Op1, Op0, ISD::SETEQ);
+
+ using namespace SDPatternMatch;
+ // Captured by m_CondCode(CC) below.
+ // NOTE(review): CC is uninitialized until that match succeeds, yet it is
+ // compared unconditionally afterwards; initializing it (e.g. to
+ // ISD::SETCC_INVALID) and using EXPECT_EQ would give a clearer failure.
+ ISD::CondCode CC;
+ // Condition-code checks on a single SETUGT compare.
+ EXPECT_TRUE(sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(),
+ m_SpecificCondCode(ISD::SETUGT))));
+ EXPECT_TRUE(
+ sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_CondCode(CC))));
+ EXPECT_TRUE(CC == ISD::SETUGT);
+ EXPECT_FALSE(sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(),
+ m_SpecificCondCode(ISD::SETLE))));
+
+ EXPECT_TRUE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op0), m_Specific(Op1),
+ m_SpecificCondCode(ISD::SETEQ))));
+ EXPECT_TRUE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op1), m_Specific(Op0),
+ m_SpecificCondCode(ISD::SETEQ))));
+ EXPECT_FALSE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op1), m_Specific(Op0),
+ m_SpecificCondCode(ISD::SETEQ))));
+ EXPECT_FALSE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op0), m_Specific(Op1),
+ m_SpecificCondCode(ISD::SETEQ))));
+ EXPECT_TRUE(sd_match(ICMP_EQ01, m_c_SetCC(m_Specific(Op1), m_Specific(Op0),
+ m_SpecificCondCode(ISD::SETEQ))));
+ EXPECT_TRUE(sd_match(ICMP_EQ10, m_c_SetCC(m_Specific(Op0), m_Specific(Op1),
+ m_SpecificCondCode(ISD::SETEQ))));
+}
+
TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
SDLoc DL;
auto Int32VT = EVT::getIntegerVT(Context, 32);
>From e0c30085578f642de9f92881f081289de5c04fa4 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Sun, 14 Jul 2024 12:24:09 +0100
Subject: [PATCH 3/3] Fix code formatting
---
.../unittests/CodeGen/SelectionDAGPatternMatchTest.cpp | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index c0ee589447ff8..a3d5e5f94b610 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -137,15 +137,15 @@ TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
EXPECT_TRUE(
sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_CondCode(CC))));
EXPECT_TRUE(CC == ISD::SETUGT);
- EXPECT_FALSE(sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(),
- m_SpecificCondCode(ISD::SETLE))));
+ EXPECT_FALSE(sd_match(
+ ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_SpecificCondCode(ISD::SETLE))));
EXPECT_TRUE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op0), m_Specific(Op1),
- m_SpecificCondCode(ISD::SETEQ))));
+ m_SpecificCondCode(ISD::SETEQ))));
EXPECT_TRUE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op1), m_Specific(Op0),
- m_SpecificCondCode(ISD::SETEQ))));
+ m_SpecificCondCode(ISD::SETEQ))));
EXPECT_FALSE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op1), m_Specific(Op0),
- m_SpecificCondCode(ISD::SETEQ))));
+ m_SpecificCondCode(ISD::SETEQ))));
EXPECT_FALSE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op0), m_Specific(Op1),
m_SpecificCondCode(ISD::SETEQ))));
EXPECT_TRUE(sd_match(ICMP_EQ01, m_c_SetCC(m_Specific(Op1), m_Specific(Op0),
More information about the llvm-commits mailing list