[llvm] [X86] Improve KnownBits for X86ISD::PSADBW nodes (PR #83830)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 4 03:42:47 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-x86
Author: Simon Pilgrim (RKSimon)
<details>
<summary>Changes</summary>
Instead of only returning the known-zero upper bits, compute the KnownBits of the per-byte absolute difference of the two operands and then perform the horizontal sum.
---
Full diff: https://github.com/llvm/llvm-project/pull/83830.diff
2 Files Affected:
- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+21-4)
- (modified) llvm/test/CodeGen/X86/psadbw.ll (+1-4)
``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b87e3121838dcc..5076ac5e347e9f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -36836,12 +36836,23 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
break;
}
case X86ISD::PSADBW: {
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
assert(VT.getScalarType() == MVT::i64 &&
- Op.getOperand(0).getValueType().getScalarType() == MVT::i8 &&
+ LHS.getValueType() == RHS.getValueType() &&
+ LHS.getValueType().getScalarType() == MVT::i8 &&
"Unexpected PSADBW types");
- // PSADBW - fills low 16 bits and zeros upper 48 bits of each i64 result.
- Known.Zero.setBitsFrom(16);
+ KnownBits Known2;
+ unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+ APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+ Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
+ Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
+ Known = KnownBits::absdiff(Known, Known2).zext(16);
+ Known = KnownBits::computeForAddSub(true, true, Known, Known);
+ Known = KnownBits::computeForAddSub(true, true, Known, Known);
+ Known = KnownBits::computeForAddSub(true, true, Known, Known);
+ Known = Known.zext(64);
break;
}
case X86ISD::PCMPGT:
@@ -54853,6 +54864,7 @@ static SDValue combineSub(SDNode *N, SelectionDAG &DAG,
}
static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
MVT VT = N->getSimpleValueType(0);
SDLoc DL(N);
@@ -54864,6 +54876,11 @@ static SDValue combineVectorCompare(SDNode *N, SelectionDAG &DAG,
return DAG.getConstant(0, DL, VT);
}
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (TLI.SimplifyDemandedBits(
+ SDValue(N, 0), APInt::getAllOnes(VT.getScalarSizeInBits()), DCI))
+ return SDValue(N, 0);
+
return SDValue();
}
@@ -56587,7 +56604,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::MGATHER:
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
case X86ISD::PCMPEQ:
- case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget);
+ case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, DCI, Subtarget);
case X86ISD::PMULDQ:
case X86ISD::PMULUDQ: return combinePMULDQ(N, DAG, DCI, Subtarget);
case X86ISD::VPMADDUBSW:
diff --git a/llvm/test/CodeGen/X86/psadbw.ll b/llvm/test/CodeGen/X86/psadbw.ll
index 8141b22d321f4d..8044472b13e3a8 100644
--- a/llvm/test/CodeGen/X86/psadbw.ll
+++ b/llvm/test/CodeGen/X86/psadbw.ll
@@ -70,10 +70,7 @@ define <2 x i64> @combine_psadbw_cmp_knownbits(<16 x i8> %a0) nounwind {
;
; AVX2-LABEL: combine_psadbw_cmp_knownbits:
; AVX2: # %bb.0:
-; AVX2-NEXT: vpand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
-; AVX2-NEXT: vpsadbw %xmm1, %xmm0, %xmm0
-; AVX2-NEXT: vpcmpgtq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; AVX2-NEXT: vxorps %xmm0, %xmm0, %xmm0
; AVX2-NEXT: retq
%mask = and <16 x i8> %a0, <i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3, i8 3>
%sad = tail call <2 x i64> @llvm.x86.sse2.psad.bw(<16 x i8> %mask, <16 x i8> zeroinitializer)
``````````
</details>
https://github.com/llvm/llvm-project/pull/83830
More information about the llvm-commits mailing list