[llvm] ca4b1f8 - [X86] computeKnownBitsForTargetNode - add handling for PMADDWD/PMADDUBSW nodes
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 28 08:29:32 PDT 2024
Author: Simon Pilgrim
Date: 2024-06-28T16:29:14+01:00
New Revision: ca4b1f8629d3162a572a0888322789c56bb75921
URL: https://github.com/llvm/llvm-project/commit/ca4b1f8629d3162a572a0888322789c56bb75921
DIFF: https://github.com/llvm/llvm-project/commit/ca4b1f8629d3162a572a0888322789c56bb75921.diff
LOG: [X86] computeKnownBitsForTargetNode - add handling for PMADDWD/PMADDUBSW nodes
These were previously reverted in fa0e9acea5e4d363eef6acc484afc1b22ab8e698 while an infinite loop regression was triaged; this commit relands them.
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/combine-pmadd.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 0cd5b09b10ba8..50cacc3038b0c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37139,6 +37139,52 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
Known = Known.zext(64);
}
+static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
+ KnownBits &Known,
+ const APInt &DemandedElts,
+ const SelectionDAG &DAG,
+ unsigned Depth) {
+ unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+
+ // Multiply signed i16 elements to create i32 values and add Lo/Hi pairs.
+ APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+ APInt DemandedLoElts =
+ DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
+ APInt DemandedHiElts =
+ DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
+ KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
+ KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
+ KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
+ KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
+ KnownBits Lo = KnownBits::mul(LHSLo.sext(32), RHSLo.sext(32));
+ KnownBits Hi = KnownBits::mul(LHSHi.sext(32), RHSHi.sext(32));
+ Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/false,
+ /*NUW=*/false, Lo, Hi);
+}
+
+static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
+ KnownBits &Known,
+ const APInt &DemandedElts,
+ const SelectionDAG &DAG,
+ unsigned Depth) {
+ unsigned NumSrcElts = LHS.getValueType().getVectorNumElements();
+
+ // Multiply unsigned/signed i8 elements to create i16 values and add_sat Lo/Hi
+ // pairs.
+ APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
+ APInt DemandedLoElts =
+ DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b01));
+ APInt DemandedHiElts =
+ DemandedSrcElts & APInt::getSplat(NumSrcElts, APInt(2, 0b10));
+ KnownBits LHSLo = DAG.computeKnownBits(LHS, DemandedLoElts, Depth + 1);
+ KnownBits LHSHi = DAG.computeKnownBits(LHS, DemandedHiElts, Depth + 1);
+ KnownBits RHSLo = DAG.computeKnownBits(RHS, DemandedLoElts, Depth + 1);
+ KnownBits RHSHi = DAG.computeKnownBits(RHS, DemandedHiElts, Depth + 1);
+ KnownBits Lo = KnownBits::mul(LHSLo.zext(16), RHSLo.sext(16));
+ KnownBits Hi = KnownBits::mul(LHSHi.zext(16), RHSHi.sext(16));
+ Known = KnownBits::sadd_sat(Lo, Hi);
+}
+
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
KnownBits &Known,
const APInt &DemandedElts,
@@ -37314,6 +37360,26 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
}
break;
}
+ case X86ISD::VPMADDWD: {
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
+ assert(VT.getVectorElementType() == MVT::i32 &&
+ LHS.getValueType() == RHS.getValueType() &&
+ LHS.getValueType().getVectorElementType() == MVT::i16 &&
+ "Unexpected PMADDWD types");
+ computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
+ break;
+ }
+ case X86ISD::VPMADDUBSW: {
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
+ assert(VT.getVectorElementType() == MVT::i16 &&
+ LHS.getValueType() == RHS.getValueType() &&
+ LHS.getValueType().getVectorElementType() == MVT::i8 &&
+ "Unexpected PMADDUBSW types");
+ computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
+ break;
+ }
case X86ISD::PMULUDQ: {
KnownBits Known2;
Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37450,6 +37516,30 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
}
case ISD::INTRINSIC_WO_CHAIN: {
switch (Op->getConstantOperandVal(0)) {
+ case Intrinsic::x86_sse2_pmadd_wd:
+ case Intrinsic::x86_avx2_pmadd_wd:
+ case Intrinsic::x86_avx512_pmaddw_d_512: {
+ SDValue LHS = Op.getOperand(1);
+ SDValue RHS = Op.getOperand(2);
+ assert(VT.getScalarType() == MVT::i32 &&
+ LHS.getValueType() == RHS.getValueType() &&
+ LHS.getValueType().getScalarType() == MVT::i16 &&
+ "Unexpected PMADDWD types");
+ computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
+ break;
+ }
+ case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
+ case Intrinsic::x86_avx2_pmadd_ub_sw:
+ case Intrinsic::x86_avx512_pmaddubs_w_512: {
+ SDValue LHS = Op.getOperand(1);
+ SDValue RHS = Op.getOperand(2);
+ assert(VT.getScalarType() == MVT::i16 &&
+ LHS.getValueType() == RHS.getValueType() &&
+ LHS.getValueType().getScalarType() == MVT::i8 &&
+ "Unexpected PMADDUBSW types");
+ computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
+ break;
+ }
case Intrinsic::x86_sse2_psad_bw:
case Intrinsic::x86_avx2_psad_bw:
case Intrinsic::x86_avx512_psad_bw_512: {
diff --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index 83cce1061bc1e..565d9ef6eb3e6 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -229,35 +229,10 @@ define i32 @combine_pmaddubsw_constant_sat() {
; Constant folding PMADDWD was causing an infinite loop in the PCMPGT commuting between 2 constant values.
define i1 @pmaddwd_pcmpgt_infinite_loop() {
-; SSE-LABEL: pmaddwd_pcmpgt_infinite_loop:
-; SSE: # %bb.0:
-; SSE-NEXT: movdqa {{.*#+}} xmm0 = [2147483647,2147483647,2147483647,2147483647]
-; SSE-NEXT: paddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; SSE-NEXT: pcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; SSE-NEXT: movmskps %xmm0, %eax
-; SSE-NEXT: testl %eax, %eax
-; SSE-NEXT: sete %al
-; SSE-NEXT: retq
-;
-; AVX1-LABEL: pmaddwd_pcmpgt_infinite_loop:
-; AVX1: # %bb.0:
-; AVX1-NEXT: vpcmpeqd %xmm0, %xmm0, %xmm0
-; AVX1-NEXT: vbroadcastss {{.*#+}} xmm1 = [2147483647,2147483647,2147483647,2147483647]
-; AVX1-NEXT: vpaddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT: vpcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT: vtestps %xmm1, %xmm0
-; AVX1-NEXT: sete %al
-; AVX1-NEXT: retq
-;
-; AVX2-LABEL: pmaddwd_pcmpgt_infinite_loop:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vpcmpeqd %xmm0, %xmm0, %xmm0
-; AVX2-NEXT: vpbroadcastd {{.*#+}} xmm1 = [2147483647,2147483647,2147483647,2147483647]
-; AVX2-NEXT: vpaddd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX2-NEXT: vpcmpgtd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; AVX2-NEXT: vtestps %xmm1, %xmm0
-; AVX2-NEXT: sete %al
-; AVX2-NEXT: retq
+; CHECK-LABEL: pmaddwd_pcmpgt_infinite_loop:
+; CHECK: # %bb.0:
+; CHECK-NEXT: movb $1, %al
+; CHECK-NEXT: retq
%1 = tail call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>, <8 x i16> <i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768, i16 -32768>)
%2 = icmp eq <4 x i32> %1, <i32 -2147483648, i32 -2147483648, i32 -2147483648, i32 -2147483648>
%3 = select <4 x i1> %2, <4 x i32> <i32 2147483647, i32 2147483647, i32 2147483647, i32 2147483647>, <4 x i32> zeroinitializer
More information about the llvm-commits
mailing list