[llvm] [DAG] Add SDPatternMatch::m_SetCC and update some combines to use it (PR #98646)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 12 08:05:09 PDT 2024


https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/98646

The plan is to add more TernaryOp matchers in the future (SELECT/VSELECT and FMA in particular).

>From 8e863f7fd7ecef6d7b8f25525cb7f085e2d46d9d Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 12 Jul 2024 16:02:37 +0100
Subject: [PATCH] [DAG] Add SDPatternMatch::m_SetCC and update some combines to
 use it

---
 llvm/include/llvm/CodeGen/SDPatternMatch.h    | 43 ++++++++++++++++++
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 45 ++++++-------------
 2 files changed, 57 insertions(+), 31 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index f39fbd95b3beb..ee2c0610c92b8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -447,6 +447,49 @@ template <> struct EffectiveOperands<false> {
   explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
 };
 
+// === Ternary operations ===
+template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
+          bool ExcludeChain = false>
+struct TernaryOpc_match {
+  unsigned Opcode;
+  T0_P Op0;
+  T1_P Op1;
+  T2_P Op2;
+  // The constructor copies the opcode and the three operand sub-patterns.
+  TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1,
+                   const T2_P &Op2)
+      : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {}
+  // Matches Op0/Op1 (either order if Commutable) and Op2 on the last operand.
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
+      EffectiveOperands<ExcludeChain> EO(N);
+      assert(EO.Size == 3);
+      return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+               Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
+              (Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+               Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) &&
+             Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2));
+    }
+    // The m_Opc(Opcode) check on N failed, so there is nothing to match.
+    return false;
+  }
+};
+
+// Match an ISD::SETCC node: Op0/Op1 in operand order, Op2 on operand 2.
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, false>
+m_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  return {ISD::SETCC, Op0, Op1, Op2};
+}
+
+// Commutative SETCC: Op0/Op1 match in either order; condcode is NOT swapped.
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, true, false>
+m_c_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  return {ISD::SETCC, Op0, Op1, Op2};
+}
+
 // === Binary operations ===
 template <typename LHS_P, typename RHS_P, bool Commutable = false,
           bool ExcludeChain = false>
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 428cdda21cd41..9185c0176de12 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2300,24 +2300,12 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
     return true;
   }
 
-  if (N.getOpcode() != ISD::SETCC ||
-      N.getValueType().getScalarType() != MVT::i1 ||
-      cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
-    return false;
-
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  assert(Op0.getValueType() == Op1.getValueType());
-
-  if (isNullOrNullSplat(Op0))
-    Op = Op1;
-  else if (isNullOrNullSplat(Op1))
-    Op = Op0;
-  else
+  if (N.getValueType().getScalarType() != MVT::i1 ||
+      !sd_match(N,
+                m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
     return false;
 
   Known = DAG.computeKnownBits(Op);
-
   return (Known.Zero | 1).isAllOnes();
 }
 
@@ -2544,16 +2532,12 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
     return SDValue();
 
   // Match the zext operand as a setcc of a boolean.
-  if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
-      Z.getOperand(0).getValueType() != MVT::i1)
+  if (Z.getOperand(0).getValueType() != MVT::i1)
     return SDValue();
 
   // Match the compare as: setcc (X & 1), 0, eq.
-  SDValue SetCC = Z.getOperand(0);
-  ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
-  if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
-      SetCC.getOperand(0).getOpcode() != ISD::AND ||
-      !isOneConstant(SetCC.getOperand(0).getOperand(1)))
+  if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
+                                         m_SpecificCondCode(ISD::SETEQ))))
     return SDValue();
 
   // We are adding/subtracting a constant and an inverted low bit. Turn that
@@ -2561,9 +2545,9 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
   // add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
   // sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
   EVT VT = C.getValueType();
-  SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
-  SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
-                       DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
+  SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
+  SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
+                     : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
   return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
 }
 
@@ -11554,13 +11538,12 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
   SDValue N1 = N->getOperand(1);
   SDValue N2 = N->getOperand(2);
   EVT VT = N->getValueType(0);
-  if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
-    return SDValue();
 
-  SDValue Cond0 = N0.getOperand(0);
-  SDValue Cond1 = N0.getOperand(1);
-  ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
-  if (VT != Cond0.getValueType())
+  SDValue Cond0, Cond1;
+  ISD::CondCode CC;
+  if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
+                                     m_CondCode(CC)))) ||
+      VT != Cond0.getValueType())
     return SDValue();
 
   // Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the



More information about the llvm-commits mailing list