[llvm] f0bc411 - [X86] combineBasicSADPattern - pattern match various vXi8 ABDU patterns (#147570)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 9 00:08:26 PDT 2025


Author: Simon Pilgrim
Date: 2025-07-09T08:08:22+01:00
New Revision: f0bc41181c0fd03069ca63a3b8b0f85e3c7cb477

URL: https://github.com/llvm/llvm-project/commit/f0bc41181c0fd03069ca63a3b8b0f85e3c7cb477
DIFF: https://github.com/llvm/llvm-project/commit/f0bc41181c0fd03069ca63a3b8b0f85e3c7cb477.diff

LOG: [X86] combineBasicSADPattern - pattern match various vXi8 ABDU patterns (#147570)

We were previously limited to abs(sub(zext(),zext())) patterns. This adds
handling for a number of other abdu patterns, as a stopgap until a
topologically sorted DAG allows us to rely on an ABDU node having already
been created.

Now that we no longer match only zext() sources, I've generalised the
createPSADBW helper to explicitly zext/truncate each source to the expected
vXi8 type. It still assumes the sources are valid PSADBW operands, i.e. that
every element already fits in an unsigned i8.

Fixes #143456

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/sad.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index fd617f7062313..dd0df2c561919 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -46047,23 +46047,22 @@ static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
                           DpBuilder, false);
 }
 
-// Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
-// to these zexts.
-static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
-                            const SDValue &Zext1, const SDLoc &DL,
-                            const X86Subtarget &Subtarget) {
+// Create a PSADBW given two sources representable as zexts of vXi8.
+static SDValue createPSADBW(SelectionDAG &DAG, SDValue N0, SDValue N1,
+                            const SDLoc &DL, const X86Subtarget &Subtarget) {
   // Find the appropriate width for the PSADBW.
-  EVT InVT = Zext0.getOperand(0).getValueType();
-  unsigned RegSize = std::max(128u, (unsigned)InVT.getSizeInBits());
-
-  // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
-  // fill in the missing vector elements with 0.
-  unsigned NumConcat = RegSize / InVT.getSizeInBits();
-  SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT));
-  Ops[0] = Zext0.getOperand(0);
+  EVT DstVT = N0.getValueType();
+  EVT SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i8,
+                               DstVT.getVectorElementCount());
+  unsigned RegSize = std::max(128u, (unsigned)SrcVT.getSizeInBits());
+
+  // Widen the vXi8 vectors, padding with zero vector elements.
+  unsigned NumConcat = RegSize / SrcVT.getSizeInBits();
+  SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, SrcVT));
+  Ops[0] = DAG.getZExtOrTrunc(N0, DL, SrcVT);
   MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
   SDValue SadOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
-  Ops[0] = Zext1.getOperand(0);
+  Ops[0] = DAG.getZExtOrTrunc(N1, DL, SrcVT);
   SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
 
   // Actually build the SAD, split as 128/256/512 bits for SSE/AVX2/AVX512BW.
@@ -46073,7 +46072,7 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
     return DAG.getNode(X86ISD::PSADBW, DL, VT, Ops);
   };
   MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
-  return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, { SadOp0, SadOp1 },
+  return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, {SadOp0, SadOp1},
                           PSADBWBuilder);
 }
 
@@ -46372,9 +46371,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
     return SDValue();
 
   EVT ExtractVT = Extract->getValueType(0);
-  // Verify the type we're extracting is either i32 or i64.
-  // FIXME: Could support other types, but this is what we have coverage for.
-  if (ExtractVT != MVT::i32 && ExtractVT != MVT::i64)
+  if (ExtractVT != MVT::i8 && ExtractVT != MVT::i16 && ExtractVT != MVT::i32 &&
+      ExtractVT != MVT::i64)
     return SDValue();
 
   EVT VT = Extract->getOperand(0).getValueType();
@@ -46399,20 +46397,27 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
       Root.getOpcode() == ISD::ANY_EXTEND)
     Root = Root.getOperand(0);
 
-  // Check whether we have an abdu pattern.
-  // TODO: Add handling for ISD::ABDU.
-  SDValue Zext0, Zext1;
+  // Check whether we have an vXi8 abdu pattern.
+  // TODO: Just match ISD::ABDU once the DAG is topological sorted.
+  SDValue Src0, Src1;
   if (!sd_match(
           Root,
-          m_Abs(m_Sub(m_AllOf(m_Value(Zext0),
-                              m_ZExt(m_SpecificVectorElementVT(MVT::i8))),
-                      m_AllOf(m_Value(Zext1),
-                              m_ZExt(m_SpecificVectorElementVT(MVT::i8)))))))
+          m_AnyOf(
+              m_SpecificVectorElementVT(
+                  MVT::i8, m_c_BinOp(ISD::ABDU, m_Value(Src0), m_Value(Src1))),
+              m_SpecificVectorElementVT(
+                  MVT::i8, m_Sub(m_UMax(m_Value(Src0), m_Value(Src1)),
+                                 m_UMin(m_Deferred(Src0), m_Deferred(Src1)))),
+              m_Abs(
+                  m_Sub(m_AllOf(m_Value(Src0),
+                                m_ZExt(m_SpecificVectorElementVT(MVT::i8))),
+                        m_AllOf(m_Value(Src1),
+                                m_ZExt(m_SpecificVectorElementVT(MVT::i8))))))))
     return SDValue();
 
   // Create the SAD instruction.
   SDLoc DL(Extract);
-  SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL, Subtarget);
+  SDValue SAD = createPSADBW(DAG, Src0, Src1, DL, Subtarget);
 
   // If the original vector was wider than 8 elements, sum over the results
   // in the SAD vector.

diff  --git a/llvm/test/CodeGen/X86/sad.ll b/llvm/test/CodeGen/X86/sad.ll
index 084bd797df945..7364b15045b40 100644
--- a/llvm/test/CodeGen/X86/sad.ll
+++ b/llvm/test/CodeGen/X86/sad.ll
@@ -1184,11 +1184,6 @@ define i32 @PR143456(ptr %p0, ptr %p1) {
 ; SSE2:       # %bb.0:
 ; SSE2-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
 ; SSE2-NEXT:    movq {{.*#+}} xmm1 = mem[0],zero
-; SSE2-NEXT:    movdqa %xmm0, %xmm2
-; SSE2-NEXT:    pminub %xmm1, %xmm2
-; SSE2-NEXT:    pmaxub %xmm1, %xmm0
-; SSE2-NEXT:    psubb %xmm2, %xmm0
-; SSE2-NEXT:    pxor %xmm1, %xmm1
 ; SSE2-NEXT:    psadbw %xmm0, %xmm1
 ; SSE2-NEXT:    movd %xmm1, %eax
 ; SSE2-NEXT:    movzbl %al, %eax
@@ -1198,10 +1193,6 @@ define i32 @PR143456(ptr %p0, ptr %p1) {
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
 ; AVX-NEXT:    vmovq {{.*#+}} xmm1 = mem[0],zero
-; AVX-NEXT:    vpminub %xmm1, %xmm0, %xmm2
-; AVX-NEXT:    vpmaxub %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpsubb %xmm2, %xmm0, %xmm0
-; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
 ; AVX-NEXT:    vpsadbw %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vpextrb $0, %xmm0, %eax
 ; AVX-NEXT:    retq


        


More information about the llvm-commits mailing list