[llvm] [DAG] Add SDPatternMatch::m_SetCC and update some combines to use it (PR #98646)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 12 08:05:09 PDT 2024
https://github.com/RKSimon created https://github.com/llvm/llvm-project/pull/98646
The plan is to add more TernaryOp matchers in the future (SELECT/VSELECT and FMA in particular).
>From 8e863f7fd7ecef6d7b8f25525cb7f085e2d46d9d Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 12 Jul 2024 16:02:37 +0100
Subject: [PATCH] [DAG] Add SDPatternMatch::m_SetCC and update some combines to
use it
---
llvm/include/llvm/CodeGen/SDPatternMatch.h | 43 ++++++++++++++++++
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 45 ++++++-------------
2 files changed, 57 insertions(+), 31 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index f39fbd95b3beb..ee2c0610c92b8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -447,6 +447,49 @@ template <> struct EffectiveOperands<false> {
explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {}
};
+// === Ternary operations ===
+/// Matches a node whose opcode is \p Opcode and whose three effective
+/// operands satisfy the sub-patterns \p Op0, \p Op1 and \p Op2.
+/// Operands 0 and 1 are tried in order first; if that fails and
+/// \p Commutable is set, they are tried again in swapped order. Operand 2 is
+/// always matched in place against \p Op2 (it is never commuted).
+/// \p ExcludeChain selects the EffectiveOperands specialization that supplies
+/// FirstIndex/Size (presumably skipping a chain operand - confirm against the
+/// EffectiveOperands<true> definition).
+template <typename T0_P, typename T1_P, typename T2_P, bool Commutable = false,
+          bool ExcludeChain = false>
+struct TernaryOpc_match {
+  unsigned Opcode;
+  T0_P Op0;
+  T1_P Op1;
+  T2_P Op2;
+
+  TernaryOpc_match(unsigned Opc, const T0_P &P0, const T1_P &P1,
+                   const T2_P &P2)
+      : Opcode(Opc), Op0(P0), Op1(P1), Op2(P2) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &Ctx, SDValue N) {
+    // Reject anything that is not the expected opcode up front.
+    if (!sd_context_match(N, Ctx, m_Opc(Opcode)))
+      return false;
+
+    EffectiveOperands<ExcludeChain> EO(N);
+    assert(EO.Size == 3);
+    // Fetch operands lazily, relative to the first effective operand, so each
+    // one is only read when its sub-pattern is actually evaluated.
+    auto Operand = [&](unsigned Idx) {
+      return N->getOperand(EO.FirstIndex + Idx);
+    };
+
+    bool PairMatched =
+        Op0.match(Ctx, Operand(0)) && Op1.match(Ctx, Operand(1));
+    if (!PairMatched && Commutable)
+      PairMatched = Op0.match(Ctx, Operand(1)) && Op1.match(Ctx, Operand(0));
+    return PairMatched && Op2.match(Ctx, Operand(2));
+  }
+};
+
+/// Matches (setcc Op0, Op1, Op2) with the compare operands in the given order.
+/// \p Op2 is matched against the node's third operand, i.e. the condition
+/// code (typically via m_CondCode / m_SpecificCondCode).
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, /*Commutable=*/false,
+                        /*ExcludeChain=*/false>
+m_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  return {ISD::SETCC, Op0, Op1, Op2};
+}
+
+/// Matches (setcc Op0, Op1, Op2), also accepting the two compare operands in
+/// swapped order, i.e. (setcc Op1, Op0, Op2).
+/// NOTE: only the compare operands are commuted; \p Op2 is still matched
+/// against the node's original condition code, which is NOT swapped. This is
+/// therefore only semantically sound for symmetric predicates (e.g. SETEQ,
+/// SETNE) or a cond-code pattern that accepts either orientation.
+template <typename T0_P, typename T1_P, typename T2_P>
+inline TernaryOpc_match<T0_P, T1_P, T2_P, /*Commutable=*/true,
+                        /*ExcludeChain=*/false>
+m_c_SetCC(const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) {
+  // The declared return type must be the commutable instantiation: distinct
+  // TernaryOpc_match instantiations are unrelated types with no conversion.
+  return TernaryOpc_match<T0_P, T1_P, T2_P, true, false>(ISD::SETCC, Op0, Op1,
+                                                         Op2);
+}
+
// === Binary operations ===
template <typename LHS_P, typename RHS_P, bool Commutable = false,
bool ExcludeChain = false>
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 428cdda21cd41..9185c0176de12 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2300,24 +2300,12 @@ static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
return true;
}
- if (N.getOpcode() != ISD::SETCC ||
- N.getValueType().getScalarType() != MVT::i1 ||
- cast<CondCodeSDNode>(N.getOperand(2))->get() != ISD::SETNE)
- return false;
-
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
- assert(Op0.getValueType() == Op1.getValueType());
-
- if (isNullOrNullSplat(Op0))
- Op = Op1;
- else if (isNullOrNullSplat(Op1))
- Op = Op0;
- else
+ if (N.getValueType().getScalarType() != MVT::i1 ||
+ !sd_match(N,
+ m_c_SetCC(m_Value(Op), m_Zero(), m_SpecificCondCode(ISD::SETNE))))
return false;
Known = DAG.computeKnownBits(Op);
-
return (Known.Zero | 1).isAllOnes();
}
@@ -2544,16 +2532,12 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
return SDValue();
// Match the zext operand as a setcc of a boolean.
- if (Z.getOperand(0).getOpcode() != ISD::SETCC ||
- Z.getOperand(0).getValueType() != MVT::i1)
+ if (Z.getOperand(0).getValueType() != MVT::i1)
return SDValue();
// Match the compare as: setcc (X & 1), 0, eq.
- SDValue SetCC = Z.getOperand(0);
- ISD::CondCode CC = cast<CondCodeSDNode>(SetCC->getOperand(2))->get();
- if (CC != ISD::SETEQ || !isNullConstant(SetCC.getOperand(1)) ||
- SetCC.getOperand(0).getOpcode() != ISD::AND ||
- !isOneConstant(SetCC.getOperand(0).getOperand(1)))
+ if (!sd_match(Z.getOperand(0), m_SetCC(m_And(m_Value(), m_One()), m_Zero(),
+ m_SpecificCondCode(ISD::SETEQ))))
return SDValue();
// We are adding/subtracting a constant and an inverted low bit. Turn that
@@ -2561,9 +2545,9 @@ static SDValue foldAddSubBoolOfMaskedVal(SDNode *N, const SDLoc &DL,
// add (zext i1 (seteq (X & 1), 0)), C --> sub C+1, (zext (X & 1))
// sub C, (zext i1 (seteq (X & 1), 0)) --> add C-1, (zext (X & 1))
EVT VT = C.getValueType();
- SDValue LowBit = DAG.getZExtOrTrunc(SetCC.getOperand(0), DL, VT);
- SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT) :
- DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
+ SDValue LowBit = DAG.getZExtOrTrunc(Z.getOperand(0).getOperand(0), DL, VT);
+ SDValue C1 = IsAdd ? DAG.getConstant(CN->getAPIntValue() + 1, DL, VT)
+ : DAG.getConstant(CN->getAPIntValue() - 1, DL, VT);
return DAG.getNode(IsAdd ? ISD::SUB : ISD::ADD, DL, VT, C1, LowBit);
}
@@ -11554,13 +11538,12 @@ static SDValue foldVSelectToSignBitSplatMask(SDNode *N, SelectionDAG &DAG) {
SDValue N1 = N->getOperand(1);
SDValue N2 = N->getOperand(2);
EVT VT = N->getValueType(0);
- if (N0.getOpcode() != ISD::SETCC || !N0.hasOneUse())
- return SDValue();
- SDValue Cond0 = N0.getOperand(0);
- SDValue Cond1 = N0.getOperand(1);
- ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
- if (VT != Cond0.getValueType())
+ SDValue Cond0, Cond1;
+ ISD::CondCode CC;
+ if (!sd_match(N0, m_OneUse(m_SetCC(m_Value(Cond0), m_Value(Cond1),
+ m_CondCode(CC)))) ||
+ VT != Cond0.getValueType())
return SDValue();
// Match a signbit check of Cond0 as "Cond0 s<0". Swap select operands if the
More information about the llvm-commits
mailing list