[llvm] 3dbff90 - [X86] matchPMADDWD/matchPMADDWD_2 - update to use SDPatternMatch matching. NFCI.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 6 02:01:52 PST 2024


Author: Simon Pilgrim
Date: 2024-12-06T10:01:37Z
New Revision: 3dbff90b16b5964b9fa468438ff40985be5c1ade

URL: https://github.com/llvm/llvm-project/commit/3dbff90b16b5964b9fa468438ff40985be5c1ade
DIFF: https://github.com/llvm/llvm-project/commit/3dbff90b16b5964b9fa468438ff40985be5c1ade.diff

LOG: [X86] matchPMADDWD/matchPMADDWD_2 - update to use SDPatternMatch matching. NFCI.

Prep work for #118433

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c18a4ac9acb1e4..f713f2ed209e1c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56447,9 +56447,11 @@ static SDValue combineADC(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
-static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
+static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
                             const SDLoc &DL, EVT VT,
                             const X86Subtarget &Subtarget) {
+  using namespace SDPatternMatch;
+
   // Example of pattern we try to detect:
   // t := (v8i32 mul (sext (v8i16 x0), (sext (v8i16 x1))))
   //(add (build_vector (extract_elt t, 0),
@@ -56464,15 +56466,16 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
   if (!Subtarget.hasSSE2())
     return SDValue();
 
-  if (Op0.getOpcode() != ISD::BUILD_VECTOR ||
-      Op1.getOpcode() != ISD::BUILD_VECTOR)
-    return SDValue();
-
   if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
       VT.getVectorNumElements() < 4 ||
       !isPowerOf2_32(VT.getVectorNumElements()))
     return SDValue();
 
+  SDValue Op0, Op1;
+  if (!sd_match(N, m_Add(m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op0)),
+                         m_AllOf(m_Opc(ISD::BUILD_VECTOR), m_Value(Op1)))))
+    return SDValue();
+
   // Check if one of Op0,Op1 is of the form:
   // (build_vector (extract_elt Mul, 0),
   //               (extract_elt Mul, 2),
@@ -56489,26 +56492,23 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
     SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
             Op0H = Op0->getOperand(i + 1), Op1H = Op1->getOperand(i + 1);
     // TODO: Be more tolerant to undefs.
-    if (Op0L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
-        Op1L.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
-        Op0H.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
-        Op1H.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
-      return SDValue();
-    auto *Const0L = dyn_cast<ConstantSDNode>(Op0L->getOperand(1));
-    auto *Const1L = dyn_cast<ConstantSDNode>(Op1L->getOperand(1));
-    auto *Const0H = dyn_cast<ConstantSDNode>(Op0H->getOperand(1));
-    auto *Const1H = dyn_cast<ConstantSDNode>(Op1H->getOperand(1));
-    if (!Const0L || !Const1L || !Const0H || !Const1H)
+    APInt Idx0L, Idx0H, Idx1L, Idx1H;
+    if (!sd_match(Op0L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
+                                m_ConstInt(Idx0L))) ||
+        !sd_match(Op0H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
+                                m_ConstInt(Idx0H))) ||
+        !sd_match(Op1L, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
+                                m_ConstInt(Idx1L))) ||
+        !sd_match(Op1H, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(),
+                                m_ConstInt(Idx1H))))
       return SDValue();
-    unsigned Idx0L = Const0L->getZExtValue(), Idx1L = Const1L->getZExtValue(),
-             Idx0H = Const0H->getZExtValue(), Idx1H = Const1H->getZExtValue();
     // Commutativity of mul allows factors of a product to reorder.
-    if (Idx0L > Idx1L)
+    if (Idx0L.getZExtValue() > Idx1L.getZExtValue())
       std::swap(Idx0L, Idx1L);
-    if (Idx0H > Idx1H)
+    if (Idx0H.getZExtValue() > Idx1H.getZExtValue())
       std::swap(Idx0H, Idx1H);
     // Commutativity of add allows pairs of factors to reorder.
-    if (Idx0L > Idx0H) {
+    if (Idx0L.getZExtValue() > Idx0H.getZExtValue()) {
       std::swap(Idx0L, Idx0H);
       std::swap(Idx1L, Idx1H);
     }
@@ -56555,13 +56555,12 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDValue Op0, SDValue Op1,
 // Attempt to turn this pattern into PMADDWD.
 // (add (mul (sext (build_vector)), (sext (build_vector))),
 //      (mul (sext (build_vector)), (sext (build_vector)))
-static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
+static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDNode *N,
                               const SDLoc &DL, EVT VT,
                               const X86Subtarget &Subtarget) {
-  if (!Subtarget.hasSSE2())
-    return SDValue();
+  using namespace SDPatternMatch;
 
-  if (N0.getOpcode() != ISD::MUL || N1.getOpcode() != ISD::MUL)
+  if (!Subtarget.hasSSE2())
     return SDValue();
 
   if (!VT.isVector() || VT.getVectorElementType() != MVT::i32 ||
@@ -56569,25 +56568,13 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
       !isPowerOf2_32(VT.getVectorNumElements()))
     return SDValue();
 
-  SDValue N00 = N0.getOperand(0);
-  SDValue N01 = N0.getOperand(1);
-  SDValue N10 = N1.getOperand(0);
-  SDValue N11 = N1.getOperand(1);
-
   // All inputs need to be sign extends.
   // TODO: Support ZERO_EXTEND from known positive?
-  if (N00.getOpcode() != ISD::SIGN_EXTEND ||
-      N01.getOpcode() != ISD::SIGN_EXTEND ||
-      N10.getOpcode() != ISD::SIGN_EXTEND ||
-      N11.getOpcode() != ISD::SIGN_EXTEND)
+  SDValue N00, N01, N10, N11;
+  if (!sd_match(N, m_Add(m_Mul(m_SExt(m_Value(N00)), m_SExt(m_Value(N01))),
+                         m_Mul(m_SExt(m_Value(N10)), m_SExt(m_Value(N11))))))
     return SDValue();
 
-  // Peek through the extends.
-  N00 = N00.getOperand(0);
-  N01 = N01.getOperand(0);
-  N10 = N10.getOperand(0);
-  N11 = N11.getOperand(0);
-
   // Must be extending from vXi16.
   EVT InVT = N00.getValueType();
   if (InVT.getVectorElementType() != MVT::i16 || N01.getValueType() != InVT ||
@@ -56614,34 +56601,26 @@ static SDValue matchPMADDWD_2(SelectionDAG &DAG, SDValue N0, SDValue N1,
     SDValue N10Elt = N10.getOperand(i);
     SDValue N11Elt = N11.getOperand(i);
     // TODO: Be more tolerant to undefs.
-    if (N00Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
-        N01Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
-        N10Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
-        N11Elt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
-      return SDValue();
-    auto *ConstN00Elt = dyn_cast<ConstantSDNode>(N00Elt.getOperand(1));
-    auto *ConstN01Elt = dyn_cast<ConstantSDNode>(N01Elt.getOperand(1));
-    auto *ConstN10Elt = dyn_cast<ConstantSDNode>(N10Elt.getOperand(1));
-    auto *ConstN11Elt = dyn_cast<ConstantSDNode>(N11Elt.getOperand(1));
-    if (!ConstN00Elt || !ConstN01Elt || !ConstN10Elt || !ConstN11Elt)
+    SDValue N00In, N01In, N10In, N11In;
+    APInt IdxN00, IdxN01, IdxN10, IdxN11;
+    if (!sd_match(N00Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N00In),
+                                  m_ConstInt(IdxN00))) ||
+        !sd_match(N01Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N01In),
+                                  m_ConstInt(IdxN01))) ||
+        !sd_match(N10Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N10In),
+                                  m_ConstInt(IdxN10))) ||
+        !sd_match(N11Elt, m_BinOp(ISD::EXTRACT_VECTOR_ELT, m_Value(N11In),
+                                  m_ConstInt(IdxN11))))
       return SDValue();
-    unsigned IdxN00 = ConstN00Elt->getZExtValue();
-    unsigned IdxN01 = ConstN01Elt->getZExtValue();
-    unsigned IdxN10 = ConstN10Elt->getZExtValue();
-    unsigned IdxN11 = ConstN11Elt->getZExtValue();
     // Add is commutative so indices can be reordered.
-    if (IdxN00 > IdxN10) {
+    if (IdxN00.getZExtValue() > IdxN10.getZExtValue()) {
       std::swap(IdxN00, IdxN10);
       std::swap(IdxN01, IdxN11);
     }
     // N0 indices be the even element. N1 indices must be the next odd element.
-    if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 ||
-        IdxN01 != 2 * i || IdxN11 != 2 * i + 1)
+    if (IdxN00 != 2 * i || IdxN10 != 2 * i + 1 || IdxN01 != 2 * i ||
+        IdxN11 != 2 * i + 1)
       return SDValue();
-    SDValue N00In = N00Elt.getOperand(0);
-    SDValue N01In = N01Elt.getOperand(0);
-    SDValue N10In = N10Elt.getOperand(0);
-    SDValue N11In = N11Elt.getOperand(0);
 
     // First time we find an input capture it.
     if (!In0) {
@@ -56815,9 +56794,9 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
   if (SDValue Select = pushAddIntoCmovOfConsts(N, DL, DAG, Subtarget))
     return Select;
 
-  if (SDValue MAdd = matchPMADDWD(DAG, Op0, Op1, DL, VT, Subtarget))
+  if (SDValue MAdd = matchPMADDWD(DAG, N, DL, VT, Subtarget))
     return MAdd;
-  if (SDValue MAdd = matchPMADDWD_2(DAG, Op0, Op1, DL, VT, Subtarget))
+  if (SDValue MAdd = matchPMADDWD_2(DAG, N, DL, VT, Subtarget))
     return MAdd;
   if (SDValue MAdd = combineAddOfPMADDWD(DAG, Op0, Op1, DL, VT))
     return MAdd;


        


More information about the llvm-commits mailing list