[llvm] [X86] Improve KnownBits for X86ISD::PSADBW nodes (PR #83830)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 5 06:57:22 PST 2024


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/83830

>From 2bb14ec185ba697e5091eb11b10f8c8745eeb74a Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 1 Mar 2024 11:00:09 +0000
Subject: [PATCH] [X86] Improve KnownBits for X86ISD::PSADBW nodes

Don't just return the known-zero upper bits; compute the KnownBits of the absolute differences and perform the horizontal sum.
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 51 ++++++++++++++--
 llvm/test/CodeGen/X86/psadbw.ll         | 77 ++++++-------------------
 llvm/test/CodeGen/X86/sad.ll            | 16 +----
 3 files changed, 68 insertions(+), 76 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6eaaec407dbb08..9f2d3c1ec4740a 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -36739,6 +36739,23 @@ X86TargetLowering::targetShrinkDemandedConstant(SDValue Op,
   return TLO.CombineTo(Op, NewOp);
 }
 
+static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
+                                      KnownBits &Known,
+                                      const APInt &DemandedElts,
+                                      const SelectionDAG &DAG, unsigned Depth) {
+  KnownBits Known2;
+  unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+  APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+  Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
+  Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
+  Known = KnownBits::absdiff(Known, Known2).zext(16);
+  // Known = (((D0 + D1) + (D2 + D3)) + ((D4 + D5) + (D6 + D7)))
+  Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, Known, Known);
+  Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, Known, Known);
+  Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, Known, Known);
+  Known = Known.zext(64);
+}
+
 void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
                                                       KnownBits &Known,
                                                       const APInt &DemandedElts,
@@ -36888,12 +36905,13 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     break;
   }
   case X86ISD::PSADBW: {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
     assert(VT.getScalarType() == MVT::i64 &&
-           Op.getOperand(0).getValueType().getScalarType() == MVT::i8 &&
+           LHS.getValueType() == RHS.getValueType() &&
+           LHS.getValueType().getScalarType() == MVT::i8 &&
            "Unexpected PSADBW types");
-
-    // PSADBW - fills low 16 bits and zeros upper 48 bits of each i64 result.
-    Known.Zero.setBitsFrom(16);
+    computeKnownBitsForPSADBW(LHS, RHS, Known, DemandedElts, DAG, Depth);
     break;
   }
   case X86ISD::PCMPGT:
@@ -37047,6 +37065,23 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     }
     break;
   }
+  case ISD::INTRINSIC_WO_CHAIN: {
+    switch (Op->getConstantOperandVal(0)) {
+    case Intrinsic::x86_sse2_psad_bw:
+    case Intrinsic::x86_avx2_psad_bw:
+    case Intrinsic::x86_avx512_psad_bw_512: {
+      SDValue LHS = Op.getOperand(1);
+      SDValue RHS = Op.getOperand(2);
+      assert(VT.getScalarType() == MVT::i64 &&
+             LHS.getValueType() == RHS.getValueType() &&
+             LHS.getValueType().getScalarType() == MVT::i8 &&
+             "Unexpected PSADBW types");
+      computeKnownBitsForPSADBW(LHS, RHS, Known, DemandedElts, DAG, Depth);
+      break;
+    }
+    }
+    break;
+  }
   }
 
   // Handle target shuffles.
@@ -54905,6 +54940,7 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
 }
 
 static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
+                                    TargetLowering::DAGCombinerInfo &DCI,
                                     const X86Subtarget &Subtarget) {
   MVT VT = N->getSimpleValueType(0);
   SDLoc DL(N);
@@ -54916,6 +54952,11 @@ static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
       return DAG.getConstant(0, DL, VT);
   }
 
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (TLI.SimplifyDemandedBits(
+          SDValue(N, 0), APInt::getAllOnes(VT.getScalarSizeInBits()), DCI))
+    return SDValue(N, 0);
+
   return SDValue();
 }
 
@@ -56639,7 +56680,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::MGATHER:
   case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
   case X86ISD::PCMPEQ:
-  case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);
+  case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, DCI, Subtarget);
   case X86ISD::PMULDQ:
   case X86ISD::PMULUDQ:     return combinePMULDQ(N, DAG, DCI, Subtarget);
   case X86ISD::VPMADDUBSW:
diff --git a/llvm/test/CodeGen/X86/psadbw.ll b/llvm/test/CodeGen/X86/psadbw.ll
index 8141b22d321f4d..d421a44f51a034 100644
--- a/llvm/test/CodeGen/X86/psadbw.ll
+++ b/llvm/test/CodeGen/X86/psadbw.ll
@@ -50,30 +50,20 @@ define i64 @combine_psadbw_demandedelt(<16 x i8> %0, <16 x i8> %1) nounwind {
 define <2 x i64> @combine_psadbw_cmp_knownbits(<16 x i8> %a0) nounwind {
 ; X86-SSE-LABEL: combine_psadbw_cmp_knownbits:
 ; X86-SSE:       # %bb.0:
-; X86-SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
-; X86-SSE-NEXT:    pxor %xmm1, %xmm1
-; X86-SSE-NEXT:    psadbw %xmm0, %xmm1
-; X86-SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[0,0,2,2]
-; X86-SSE-NEXT:    por {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
-; X86-SSE-NEXT:    pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
+; X86-SSE-NEXT:    xorps %xmm0, %xmm0
 ; X86-SSE-NEXT:    retl
 ;
 ; X64-SSE-LABEL: combine_psadbw_cmp_knownbits:
 ; X64-SSE:       # %bb.0:
 ; X64-SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
 ; X64-SSE-NEXT:    pxor %xmm1, %xmm1
-; X64-SSE-NEXT:    psadbw %xmm0, %xmm1
-; X64-SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[0,0,2,2]
-; X64-SSE-NEXT:    por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; X64-SSE-NEXT:    psadbw %xmm1, %xmm0
 ; X64-SSE-NEXT:    pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
 ; X64-SSE-NEXT:    retq
 ;
 ; AVX2-LABEL: combine_psadbw_cmp_knownbits:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT:    vpsadbw %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpcmpgtq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX2-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX2-NEXT:    retq
   %mask = and <16 x i8> %a0, <i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3>
   %sad = tail call <2 x i64> @llvm.x86.sse2.psad.bw(<16 x i8> %mask, <16 x i8> zeroinitializer)
@@ -82,28 +72,15 @@ define <2 x i64> @combine_psadbw_cmp_knownbits(<16 x i8> %a0) nounwind {
   ret <2 x i64> %ext
 }
 
-; TODO: No need to scalarize the sitofp as the PSADBW results are smaller than i32.
+; No need to scalarize the sitofp as the PSADBW results are smaller than i32.
 define <2 x double> @combine_psadbw_sitofp_knownbits(<16 x i8> %a0) nounwind {
 ; X86-SSE-LABEL: combine_psadbw_sitofp_knownbits:
 ; X86-SSE:       # %bb.0:
-; X86-SSE-NEXT:    pushl %ebp
-; X86-SSE-NEXT:    movl %esp, %ebp
-; X86-SSE-NEXT:    andl $-8, %esp
-; X86-SSE-NEXT:    subl $32, %esp
 ; X86-SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
 ; X86-SSE-NEXT:    pxor %xmm1, %xmm1
 ; X86-SSE-NEXT:    psadbw %xmm0, %xmm1
-; X86-SSE-NEXT:    movq %xmm1, {{[0-9]+}}(%esp)
-; X86-SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[2,3,2,3]
-; X86-SSE-NEXT:    movq %xmm0, {{[0-9]+}}(%esp)
-; X86-SSE-NEXT:    fildll {{[0-9]+}}(%esp)
-; X86-SSE-NEXT:    fstpl {{[0-9]+}}(%esp)
-; X86-SSE-NEXT:    fildll {{[0-9]+}}(%esp)
-; X86-SSE-NEXT:    fstpl (%esp)
-; X86-SSE-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
-; X86-SSE-NEXT:    movhps {{.*#+}} xmm0 = xmm0[0,1],mem[0,1]
-; X86-SSE-NEXT:    movl %ebp, %esp
-; X86-SSE-NEXT:    popl %ebp
+; X86-SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[0,2,2,3]
+; X86-SSE-NEXT:    cvtdq2pd %xmm0, %xmm0
 ; X86-SSE-NEXT:    retl
 ;
 ; X64-SSE-LABEL: combine_psadbw_sitofp_knownbits:
@@ -111,14 +88,8 @@ define <2 x double> @combine_psadbw_sitofp_knownbits(<16 x i8> %a0) nounwind {
 ; X64-SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
 ; X64-SSE-NEXT:    pxor %xmm1, %xmm1
 ; X64-SSE-NEXT:    psadbw %xmm0, %xmm1
-; X64-SSE-NEXT:    movd %xmm1, %eax
-; X64-SSE-NEXT:    xorps %xmm0, %xmm0
-; X64-SSE-NEXT:    cvtsi2sd %eax, %xmm0
-; X64-SSE-NEXT:    pshufd {{.*#+}} xmm1 = xmm1[2,3,2,3]
-; X64-SSE-NEXT:    movd %xmm1, %eax
-; X64-SSE-NEXT:    xorps %xmm1, %xmm1
-; X64-SSE-NEXT:    cvtsi2sd %eax, %xmm1
-; X64-SSE-NEXT:    unpcklpd {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; X64-SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[0,2,2,3]
+; X64-SSE-NEXT:    cvtdq2pd %xmm0, %xmm0
 ; X64-SSE-NEXT:    retq
 ;
 ; AVX2-LABEL: combine_psadbw_sitofp_knownbits:
@@ -126,10 +97,8 @@ define <2 x double> @combine_psadbw_sitofp_knownbits(<16 x i8> %a0) nounwind {
 ; AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
 ; AVX2-NEXT:    vpsadbw %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vcvtdq2pd %xmm0, %xmm1
-; AVX2-NEXT:    vpextrq $1, %xmm0, %rax
-; AVX2-NEXT:    vcvtsi2sd %eax, %xmm2, %xmm0
-; AVX2-NEXT:    vunpcklpd {{.*#+}} xmm0 = xmm1[0],xmm0[0]
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; AVX2-NEXT:    vcvtdq2pd %xmm0, %xmm0
 ; AVX2-NEXT:    retq
   %mask = and <16 x i8> %a0, <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>
   %sad = tail call <2 x i64> @llvm.x86.sse2.psad.bw(<16 x i8> %mask, <16 x i8> zeroinitializer)
@@ -137,28 +106,24 @@ define <2 x double> @combine_psadbw_sitofp_knownbits(<16 x i8> %a0) nounwind {
   ret <2 x double> %cvt
 }
 
-; TODO: Convert from uitofp to sitofp as the PSADBW results are zero-extended.
+; Convert from uitofp to sitofp as the PSADBW results are zero-extended.
 define <2 x double> @combine_psadbw_uitofp_knownbits(<16 x i8> %a0) nounwind {
 ; X86-SSE-LABEL: combine_psadbw_uitofp_knownbits:
 ; X86-SSE:       # %bb.0:
 ; X86-SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
 ; X86-SSE-NEXT:    pxor %xmm1, %xmm1
-; X86-SSE-NEXT:    psadbw %xmm1, %xmm0
-; X86-SSE-NEXT:    por {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0
-; X86-SSE-NEXT:    movapd {{.*#+}} xmm1 = [0,1160773632,0,1160773632]
-; X86-SSE-NEXT:    subpd {{\.?LCPI[0-9]+_[0-9]+}}, %xmm1
-; X86-SSE-NEXT:    addpd %xmm1, %xmm0
+; X86-SSE-NEXT:    psadbw %xmm0, %xmm1
+; X86-SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[0,2,2,3]
+; X86-SSE-NEXT:    cvtdq2pd %xmm0, %xmm0
 ; X86-SSE-NEXT:    retl
 ;
 ; X64-SSE-LABEL: combine_psadbw_uitofp_knownbits:
 ; X64-SSE:       # %bb.0:
 ; X64-SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
 ; X64-SSE-NEXT:    pxor %xmm1, %xmm1
-; X64-SSE-NEXT:    psadbw %xmm1, %xmm0
-; X64-SSE-NEXT:    por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; X64-SSE-NEXT:    movapd {{.*#+}} xmm1 = [4985484787499139072,4985484787499139072]
-; X64-SSE-NEXT:    subpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; X64-SSE-NEXT:    addpd %xmm1, %xmm0
+; X64-SSE-NEXT:    psadbw %xmm0, %xmm1
+; X64-SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm1[0,2,2,3]
+; X64-SSE-NEXT:    cvtdq2pd %xmm0, %xmm0
 ; X64-SSE-NEXT:    retq
 ;
 ; AVX2-LABEL: combine_psadbw_uitofp_knownbits:
@@ -166,12 +131,8 @@ define <2 x double> @combine_psadbw_uitofp_knownbits(<16 x i8> %a0) nounwind {
 ; AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
 ; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
 ; AVX2-NEXT:    vpsadbw %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
-; AVX2-NEXT:    vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; AVX2-NEXT:    vmovddup {{.*#+}} xmm1 = [4985484787499139072,4985484787499139072]
-; AVX2-NEXT:    # xmm1 = mem[0,0]
-; AVX2-NEXT:    vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX2-NEXT:    vaddpd %xmm1, %xmm0, %xmm0
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm0 = xmm0[0,2,2,3]
+; AVX2-NEXT:    vcvtdq2pd %xmm0, %xmm0
 ; AVX2-NEXT:    retq
   %mask = and <16 x i8> %a0, <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>
   %sad = tail call <2 x i64> @llvm.x86.sse2.psad.bw(<16 x i8> %mask, <16 x i8> zeroinitializer)
diff --git a/llvm/test/CodeGen/X86/sad.ll b/llvm/test/CodeGen/X86/sad.ll
index 2a33e75a8357c6..ca319687da54d4 100644
--- a/llvm/test/CodeGen/X86/sad.ll
+++ b/llvm/test/CodeGen/X86/sad.ll
@@ -989,9 +989,7 @@ define dso_local i32 @sad_unroll_nonzero_initial(ptr %arg, ptr %arg1, ptr %arg2,
 ; SSE2-NEXT:    paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,3,2,3]
 ; SSE2-NEXT:    paddd %xmm2, %xmm0
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
-; SSE2-NEXT:    paddd %xmm0, %xmm1
-; SSE2-NEXT:    movd %xmm1, %eax
+; SSE2-NEXT:    movd %xmm0, %eax
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: sad_unroll_nonzero_initial:
@@ -1053,9 +1051,7 @@ define dso_local i32 @sad_double_reduction(ptr %arg, ptr %arg1, ptr %arg2, ptr %
 ; SSE2-NEXT:    paddd %xmm1, %xmm2
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,3,2,3]
 ; SSE2-NEXT:    paddd %xmm2, %xmm0
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
-; SSE2-NEXT:    por %xmm0, %xmm1
-; SSE2-NEXT:    movd %xmm1, %eax
+; SSE2-NEXT:    movd %xmm0, %eax
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: sad_double_reduction:
@@ -1067,8 +1063,6 @@ define dso_local i32 @sad_double_reduction(ptr %arg, ptr %arg1, ptr %arg2, ptr %
 ; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
 ; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
 ; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
-; AVX-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vmovd %xmm0, %eax
 ; AVX-NEXT:    retq
 bb:
@@ -1115,9 +1109,7 @@ define dso_local i32 @sad_double_reduction_abs(ptr %arg, ptr %arg1, ptr %arg2, p
 ; SSE2-NEXT:    paddd %xmm1, %xmm2
 ; SSE2-NEXT:    pshufd {{.*#+}} xmm0 = xmm2[2,3,2,3]
 ; SSE2-NEXT:    paddd %xmm2, %xmm0
-; SSE2-NEXT:    pshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
-; SSE2-NEXT:    por %xmm0, %xmm1
-; SSE2-NEXT:    movd %xmm1, %eax
+; SSE2-NEXT:    movd %xmm0, %eax
 ; SSE2-NEXT:    retq
 ;
 ; AVX-LABEL: sad_double_reduction_abs:
@@ -1129,8 +1121,6 @@ define dso_local i32 @sad_double_reduction_abs(ptr %arg, ptr %arg1, ptr %arg2, p
 ; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
 ; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,2,3]
 ; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
-; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[1,1,1,1]
-; AVX-NEXT:    vpor %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vmovd %xmm0, %eax
 ; AVX-NEXT:    retq
 bb:



More information about the llvm-commits mailing list