[llvm] [X86] Improve KnownBits for X86ISD::PSADBW nodes (PR #83830)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 4 06:41:55 PST 2024


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/83830

>From cf8f19bcee6c3f4b5a90f4d49ac1feba104ffdfa Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Fri, 1 Mar 2024 11:00:09 +0000
Subject: [PATCH] [X86] Improve KnownBits for X86ISD::PSADBW nodes

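Don't just return the known-zero upper bits; compute the KnownBits of the per-byte absolute differences (absdiff) and then perform the horizontal sum. Also call SimplifyDemandedBits on X86ISD::PCMPEQ/PCMPGT nodes in combineVectorCompare, so that compares whose result is fully determined by known bits can be simplified (e.g. combine_psadbw_cmp_knownbits now folds to zero).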
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 26 +++++++++++++++++++++----
 llvm/test/CodeGen/X86/psadbw.ll         |  5 +----
 2 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b87e3121838dcc..77439170254e3c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -36836,12 +36836,24 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     break;
   }
   case X86ISD::PSADBW: {
+    SDValue LHS = Op.getOperand(0);
+    SDValue RHS = Op.getOperand(1);
     assert(VT.getScalarType() == MVT::i64 &&
-           Op.getOperand(0).getValueType().getScalarType() == MVT::i8 &&
+           LHS.getValueType() == RHS.getValueType() &&
+           LHS.getValueType().getScalarType() == MVT::i8 &&
            "Unexpected PSADBW types");
 
-    // PSADBW - fills low 16 bits and zeros upper 48 bits of each i64 result.
-    Known.Zero.setBitsFrom(16);
+    KnownBits Known2;
+    unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+    APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+    Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
+    Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
+    Known = KnownBits::absdiff(Known, Known2).zext(16);
+    // Known = (((D0 + D1) + (D2 + D3)) + ((D4 + D5) + (D6 + D7)))
+    Known = KnownBits::computeForAddSub(true, true, Known, Known);
+    Known = KnownBits::computeForAddSub(true, true, Known, Known);
+    Known = KnownBits::computeForAddSub(true, true, Known, Known);
+    Known = Known.zext(64);
     break;
   }
   case X86ISD::PCMPGT:
@@ -54853,6 +54865,7 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
 }
 
 static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
+                                    TargetLowering::DAGCombinerInfo &DCI,
                                     const X86Subtarget &Subtarget) {
   MVT VT = N->getSimpleValueType(0);
   SDLoc DL(N);
@@ -54864,6 +54877,11 @@ static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
       return DAG.getConstant(0, DL, VT);
   }
 
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (TLI.SimplifyDemandedBits(
+          SDValue(N, 0), APInt::getAllOnes(VT.getScalarSizeInBits()), DCI))
+    return SDValue(N, 0);
+
   return SDValue();
 }
 
@@ -56587,7 +56605,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::MGATHER:
   case ISD::MSCATTER:       return combineGatherScatter(N, DAG, DCI);
   case X86ISD::PCMPEQ:
-  case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, Subtarget);
+  case X86ISD::PCMPGT:      return combineVectorCompare(N, DAG, DCI, Subtarget);
   case X86ISD::PMULDQ:
   case X86ISD::PMULUDQ:     return combinePMULDQ(N, DAG, DCI, Subtarget);
   case X86ISD::VPMADDUBSW:
diff --git a/llvm/test/CodeGen/X86/psadbw.ll b/llvm/test/CodeGen/X86/psadbw.ll
index 8141b22d321f4d..8044472b13e3a8 100644
--- a/llvm/test/CodeGen/X86/psadbw.ll
+++ b/llvm/test/CodeGen/X86/psadbw.ll
@@ -70,10 +70,7 @@ define <2 x i64> @combine_psadbw_cmp_knownbits(<16 x i8> %a0) nounwind {
 ;
 ; AVX2-LABEL: combine_psadbw_cmp_knownbits:
 ; AVX2:       # %bb.0:
-; AVX2-NEXT:    vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT:    vpsadbw %xmm1, %xmm0, %xmm0
-; AVX2-NEXT:    vpcmpgtq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX2-NEXT:    vxorps %xmm0, %xmm0, %xmm0
 ; AVX2-NEXT:    retq
   %mask = and <16 x i8> %a0, <i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3>
   %sad = tail call <2 x i64> @llvm.x86.sse2.psad.bw(<16 x i8> %mask, <16 x i8> zeroinitializer)



More information about the llvm-commits mailing list