[llvm] 3dbff90 - [X86] matchPMADDWD/matchPMADDWD_2 - update to use SDPatternMatch matching. NFCI.
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 6 02:01:52 PST 2024
Author: Simon Pilgrim
Date: 2024-12-06T10:01:37Z
New Revision: 3dbff90b16b5964b9fa468438ff40985be5c1ade
URL: https://github.com/llvm/llvm-project/commit/3dbff90b16b5964b9fa468438ff40985be5c1ade
DIFF: https://github.com/llvm/llvm-project/commit/3dbff90b16b5964b9fa468438ff40985be5c1ade.diff
LOG: [X86] matchPMADDWD/matchPMADDWD_2 - update to use SDPatternMatch matching. NFCI.
Prep work for #118433
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c18a4ac9acb1e4..f713f2ed209e1c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56447,9 +56447,11 @@ static SDValue combineADC(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
-static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
+static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
const SDLoc &DL, EVT VT,
const X86Subtarget &Subtarget) {
+ using namespace SDPatternMatch;
+
// Example of pattern we try to detect:
// t := (v8i32 mul (sext (v8i16 x0), (sext (v8i16 x1))))
//(add (build_vector (extract_elt t, 0),
@@ -56464,15 +56466,16 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
if (!Subtarget.hasSSE2())
return SDValue();
- if (Op0.getOpcode() != ISD::BUILD_VECTOR ||
- Op1.getOpcode() != ISD::BUILD_VECTOR)
- return SDValue();
-
if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
VT.getVectorNumElements() < 4 ||
!isPowerOf2_32(VT.getVectorNumElements()))
return SDValue();
+ SDValue Op0, Op1;
+ if (!sd_match(N, m_Add(m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op0)),
+ m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op1)))))
+ return SDValue();
+
// Check if one of Op0,Op1 is of the form:
// (build_vector (extract_elt Mul, 0),
// (extract_elt Mul, 2),
@@ -56489,26 +56492,23 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
Op0H = Op0->getOperand(i + 1), Op1H = Op1->getOperand(i + 1);
// TODO: Be more tolerant to undefs.
- if (Op0L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
- Op1L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
- Op0H.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
- Op1H.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
- return SDValue();
- auto *Const0L = dyn_cast<ConstantSDNode>(Op0L->getOperand(1));
- auto *Const1L = dyn_cast<ConstantSDNode>(Op1L->getOperand(1));
- auto *Const0H = dyn_cast<ConstantSDNode>(Op0H->getOperand(1));
- auto *Const1H = dyn_cast<ConstantSDNode>(Op1H->getOperand(1));
- if (!Const0L || !Const1L || !Const0H || !Const1H)
+ APInt Idx0L, Idx0H, Idx1L, Idx1H;
+ if (!sd_match(Op0L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
+ m_ConstInt(Idx0L))) ||
+ !sd_match(Op0H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
+ m_ConstInt(Idx0H))) ||
+ !sd_match(Op1L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
+ m_ConstInt(Idx1L))) ||
+ !sd_match(Op1H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
+ m_ConstInt(Idx1H))))
return SDValue();
- unsigned Idx0L = Const0L->getZExtValue(), Idx1L = Const1L->getZExtValue(),
- Idx0H = Const0H->getZExtValue(), Idx1H = Const1H->getZExtValue();
// Commutativity of mul allows factors of a product to reorder.
- if (Idx0L > Idx1L)
+ if (Idx0L.getZExtValue() > Idx1L.getZExtValue())
std::swap(Idx0L, Idx1L);
- if (Idx0H > Idx1H)
+ if (Idx0H.getZExtValue() > Idx1H.getZExtValue())
std::swap(Idx0H, Idx1H);
// Commutativity of add allows pairs of factors to reorder.
- if (Idx0L > Idx0H) {
+ if (Idx0L.getZExtValue() > Idx0H.getZExtValue()) {
std::swap(Idx0L, Idx0H);
std::swap(Idx1L, Idx1H);
}
@@ -56555,13 +56555,12 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
// Attempt to turn this pattern into PMADDWD.
// (add (mul (sext (build_vector)), (sext (build_vector))),
// (mul (sext (build_vector)), (sext (build_vector)))
-static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
+static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDNode *N,
const SDLoc &DL, EVT VT,
const X86Subtarget &Subtarget) {
- if (!Subtarget.hasSSE2())
- return SDValue();
+ using namespace SDPatternMatch;
- if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
+ if (!Subtarget.hasSSE2())
return SDValue();
if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
@@ -56569,25 +56568,13 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
!isPowerOf2_32(VT.getVectorNumElements()))
return SDValue();
- SDValue N00 = N0.getOperand(0);
- SDValue N01 = N0.getOperand(1);
- SDValue N10 = N1.getOperand(0);
- SDValue N11 = N1.getOperand(1);
-
// All inputs need to be sign extends.
// TODO: Support ZERO_EXTEND from known positive?
- if (N00.getOpcode() != ISD::SIGN_EXTEND ||
- N01.getOpcode() != ISD::SIGN_EXTEND ||
- N10.getOpcode() != ISD::SIGN_EXTEND ||
- N11.getOpcode() != ISD::SIGN_EXTEND)
+ SDValue N00, N01, N10, N11;
+ if (!sd_match(N, m_Add(m_Mul(m_SExt(m_Value(N00)), m_SExt(m_Value(N01))),
+ m_Mul(m_SExt(m_Value(N10)), m_SExt(m_Value(N11))))))
return SDValue();
- // Peek through the extends.
- N00 = N00.getOperand(0);
- N01 = N01.getOperand(0);
- N10 = N10.getOperand(0);
- N11 = N11.getOperand(0);
-
// Must be extending from vXi16.
EVT InVT = N00.getValueType();
if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
@@ -56614,34 +56601,26 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
SDValue N10Elt = N10.getOperand(i);
SDValue N11Elt = N11.getOperand(i);
// TODO: Be more tolerant to undefs.
- if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
- N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
- N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
- N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
- return SDValue();
- auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
- auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
- auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
- auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
- if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
+ SDValue N00In, N01In, N10In, N11In;
+ APInt IdxN00, IdxN01, IdxN10, IdxN11;
+ if (!sd_match(N00Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N00In),
+ m_ConstInt(IdxN00))) ||
+ !sd_match(N01Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N01In),
+ m_ConstInt(IdxN01))) ||
+ !sd_match(N10Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N10In),
+ m_ConstInt(IdxN10))) ||
+ !sd_match(N11Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N11In),
+ m_ConstInt(IdxN11))))
return SDValue();
- unsigned IdxN00 = ConstN00Elt->getZExtValue();
- unsigned IdxN01 = ConstN01Elt->getZExtValue();
- unsigned IdxN10 = ConstN10Elt->getZExtValue();
- unsigned IdxN11 = ConstN11Elt->getZExtValue();
// Add is commutative so indices can be reordered.
- if (IdxN00 > IdxN10) {
+ if (IdxN00.getZExtValue() > IdxN10.getZExtValue()) {
std::swap(IdxN00, IdxN10);
std::swap(IdxN01, IdxN11);
}
// N0 indices be the even element. N1 indices must be the next odd element.
- if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
- IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
+ if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || IdxN01 != 2 * i ||
+ IdxN11 != 2 * i + 1)
return SDValue();
- SDValue N00In = N00Elt.getOperand(0);
- SDValue N01In = N01Elt.getOperand(0);
- SDValue N10In = N10Elt.getOperand(0);
- SDValue N11In = N11Elt.getOperand(0);
// First time we find an input capture it.
if (!In0) {
@@ -56815,9 +56794,9 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
if (SDValue Select = pushAddIntoCmovOfConsts(N, DL, DAG, Subtarget))
return Select;
- if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, DL, VT, Subtarget))
+ if (SDValue MAdd = matchPMADDWD(DAG, N, DL, VT, Subtarget))
return MAdd;
- if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, DL, VT, Subtarget))
+ if (SDValue MAdd = matchPMADDWD_2(DAG, N, DL, VT, Subtarget))
return MAdd;
if (SDValue MAdd = combineAddOfPMADDWD(DAG, Op0, Op1, DL, VT))
return MAdd;
More information about the llvm-commits mailing list.