[llvm] 61a4e1e - [DAG] Add SDPatternMatch::m_SetCC and update some combines to use it (#98646)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 14 09:18:47 PDT 2024


Author: Simon Pilgrim
Date: 2024-07-14T17:18:43+01:00
New Revision: 61a4e1e70f07c89bd890ef2bc61a818e6a321d2d

URL: https://github.com/llvm/llvm-project/commit/61a4e1e70f07c89bd890ef2bc61a818e6a321d2d
DIFF: https://github.com/llvm/llvm-project/commit/61a4e1e70f07c89bd890ef2bc61a818e6a321d2d.diff

LOG: [DAG] Add SDPatternMatch::m_SetCC and update some combines to use it (#98646)

The plan is to add more TernaryOp matchers in the future (SELECT/VSELECT and FMA in particular).

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SDPatternMatch.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index f39fbd95b3beb..07204d1f48c24 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -447,6 +447,62 @@ template <> struct EffectiveOperands<false> {
   explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
 };
 
+// === Ternary operations ===
+/// Matches a node whose opcode is \p Opcode. Such a node must have exactly
+/// three effective operands (see EffectiveOperands; asserted), which are
+/// matched against Op0, Op1 and Op2. When \p Commutable is true the first
+/// two operands may also match in swapped order; the third never commutes.
+template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
+          bool ExcludeChain = false>
+struct TernaryOpc_match {
+  unsigned Opcode;
+  T0_P Op0;
+  T1_P Op1;
+  T2_P Op2;
+
+  TernaryOpc_match(unsigned Opc, const T0_P &O0, const T1_P &O1,
+                   const T2_P &O2)
+      : Opcode(Opc), Op0(O0), Op1(O1), Op2(O2) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
+      EffectiveOperands<ExcludeChain> EO(N);
+      assert(EO.Size == 3);
+      // Try the first two operands in the given order, then (if allowed)
+      // swapped. Op2 is only checked once the first two have matched.
+      return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+               Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
+              (Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+               Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) &&
+             Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2));
+    }
+
+    return false;
+  }
+};
+
+/// Matches (setcc LHS, RHS, CC) with the operands in exactly this order.
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, false, false>
+m_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
+  return TernaryOpc_match<T0_P, T1_P, T2_P, false, false>(ISD::SETCC, LHS, RHS,
+                                                          CC);
+}
+
+/// Matches (setcc LHS, RHS, CC) or (setcc RHS, LHS, CC).
+/// NOTE: CC is matched against the node's condition code as-is; it is NOT
+/// swapped when the operands are commuted. This is only semantically sound
+/// for predicates that are symmetric in their operands (e.g. SETEQ/SETNE).
+/// A condition code captured via m_CondCode refers to the node's own operand
+/// order, which may differ from the pattern's.
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
+m_c_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) {
+  return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, LHS, RHS,
+                                                         CC);
+}
+
 // === Binary operations ===
 template <typename LHS_P, typename RHS_P, bool Commutable = false,
           bool ExcludeChain = false>

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index cece76f658307..2f1bcc9bed88b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2300,24 +2300,12 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
     return true;
   }
 
-  if (N.getOpcode() != ISD::SETCC ||
-      N.getValueType().getScalarType() != MVT::i1 ||
-      cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
-    return false;
-
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  assert(Op0.getValueType() == Op1.getValueType());
-
-  if (isNullOrNullSplat(Op0))
-    Op = Op1;
-  else if (isNullOrNullSplat(Op1))
-    Op = Op0;
-  else
+  if (N.getValueType().getScalarType() != MVT::i1 ||
+      !sd_match(
+          N, m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
     return false;
 
   Known = DAG.computeKnownBits(Op);
-
   return (Known.Zero | 1).isAllOnes();
 }
 
@@ -2544,16 +2532,12 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
     return SDValue();
 
   // Match the zext operand as a setcc of a boolean.
-  if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
-      Z.getOperand(0).getValueType() != MVT::i1)
+  if (Z.getOperand(0).getValueType() != MVT::i1)
     return SDValue();
 
   // Match the compare as: setcc (X & 1), 0, eq.
-  SDValue SetCC = Z.getOperand(0);
-  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
-  if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
-      SetCC.getOperand(0).getOpcode() != ISD::AND ||
-      !isOneConstant(SetCC.getOperand(0).getOperand(1)))
+  if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
+                                         m_SpecificCondCode(ISD::SETEQ))))
     return SDValue();
 
   // We are adding/subtracting a constant and an inverted low bit. Turn that
@@ -2561,9 +2545,9 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
   EVT VT = C.getValueType();
-  SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
-  SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
-                       DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
+  SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
+  SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
+                     : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
 }
 
@@ -11554,13 +11538,12 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
   SDValue N1 = N->getOperand(1);
   SDValue N2 = N->getOperand(2);
   EVT VT = N->getValueType(0);
-  if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
-    return SDValue();
 
-  SDValue Cond0 = N0.getOperand(0);
-  SDValue Cond1 = N0.getOperand(1);
-  ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
-  if (VT != Cond0.getValueType())
+  SDValue Cond0, Cond1;
+  ISD::CondCode CC;
+  if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
+                                     m_CondCode(CC)))) ||
+      VT != Cond0.getValueType())
     return SDValue();
 
   // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the

diff  --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 46c385a0bc050..a3d5e5f94b610 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -119,6 +119,49 @@ TEST_F(SelectionDAGPatternMatchTest, matchValueType) {
   EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT()));
 }
 
+TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {
+  SDLoc DL;
+  auto Int32VT = EVT::getIntegerVT(Context, 32);
+
+  SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
+  SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
+
+  SDValue ICMP_UGT = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETUGT);
+  SDValue ICMP_EQ01 = DAG->getSetCC(DL, MVT::i1, Op0, Op1, ISD::SETEQ);
+  SDValue ICMP_EQ10 = DAG->getSetCC(DL, MVT::i1, Op1, Op0, ISD::SETEQ);
+
+  using namespace SDPatternMatch;
+  // Initialized so that a failed m_CondCode match below cannot lead to
+  // reading an indeterminate value in the following EXPECT_EQ.
+  ISD::CondCode CC = ISD::SETCC_INVALID;
+  EXPECT_TRUE(sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(),
+                                         m_SpecificCondCode(ISD::SETUGT))));
+  EXPECT_TRUE(
+      sd_match(ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_CondCode(CC))));
+  EXPECT_EQ(CC, ISD::SETUGT);
+  EXPECT_FALSE(sd_match(
+      ICMP_UGT, m_SetCC(m_Value(), m_Value(), m_SpecificCondCode(ISD::SETLE))));
+
+  EXPECT_TRUE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op0), m_Specific(Op1),
+                                          m_SpecificCondCode(ISD::SETEQ))));
+  EXPECT_TRUE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op1), m_Specific(Op0),
+                                          m_SpecificCondCode(ISD::SETEQ))));
+  EXPECT_FALSE(sd_match(ICMP_EQ01, m_SetCC(m_Specific(Op1), m_Specific(Op0),
+                                           m_SpecificCondCode(ISD::SETEQ))));
+  EXPECT_FALSE(sd_match(ICMP_EQ10, m_SetCC(m_Specific(Op0), m_Specific(Op1),
+                                           m_SpecificCondCode(ISD::SETEQ))));
+  EXPECT_TRUE(sd_match(ICMP_EQ01, m_c_SetCC(m_Specific(Op1), m_Specific(Op0),
+                                            m_SpecificCondCode(ISD::SETEQ))));
+  EXPECT_TRUE(sd_match(ICMP_EQ10, m_c_SetCC(m_Specific(Op0), m_Specific(Op1),
+                                            m_SpecificCondCode(ISD::SETEQ))));
+  // Commuting only applies to the first two operands; the condition code
+  // must still match.
+  EXPECT_FALSE(sd_match(ICMP_EQ10, m_c_SetCC(m_Specific(Op0), m_Specific(Op1),
+                                             m_SpecificCondCode(ISD::SETNE))));
+  // A node that is not a SETCC never matches.
+  EXPECT_FALSE(sd_match(Op0, m_SetCC(m_Value(), m_Value(), m_Value())));
+}
+
 TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   SDLoc DL;
   auto Int32VT = EVT::getIntegerVT(Context, 32);


        


More information about the llvm-commits mailing list