[llvm] 0938cdb - [X86] computeKnownBitsForTargetNode - add handling for (V)PMADDUBSW nodes
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Sat Jun 15 10:32:20 PDT 2024
Author: Simon Pilgrim
Date: 2024-06-15T18:25:11+01:00
New Revision: 0938cdbfbd946e4033b070ec4fbd9dfd790eb91a
URL: https://github.com/llvm/llvm-project/commit/0938cdbfbd946e4033b070ec4fbd9dfd790eb91a
DIFF: https://github.com/llvm/llvm-project/commit/0938cdbfbd946e4033b070ec4fbd9dfd790eb91a.diff
LOG: [X86] computeKnownBitsForTargetNode - add handling for (V)PMADDUBSW nodes
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/combine-pmadd.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6aa1a5b52bb67..f27c935812f51 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -37109,6 +37109,33 @@ static void computeKnownBitsForPMADDWD(SDValue LHS, SDValue RHS,
/*NUW=*/false, Lo, Hi);
}
+/// Known bits of a (V)PMADDUBSW result: every i16 lane is the signed
+/// saturating sum of zext(LHS[2i]) * sext(RHS[2i]) and
+/// zext(LHS[2i+1]) * sext(RHS[2i+1]), where LHS/RHS are the i8 sources.
+static void computeKnownBitsForPMADDUBSW(SDValue LHS, SDValue RHS,
+                                         KnownBits &Known,
+                                         const APInt &DemandedElts,
+                                         const SelectionDAG &DAG,
+                                         unsigned Depth) {
+  // Each demanded i16 lane maps onto two adjacent i8 source lanes.
+  unsigned SrcLanes = LHS.getValueType().getVectorNumElements();
+  APInt SrcMask = APIntOps::ScaleBitMask(DemandedElts, SrcLanes);
+
+  // Widened product of one half of every pair (PairBit 0b01 selects the even
+  // lanes, 0b10 the odd ones). LHS is unsigned, RHS signed; the i16 product
+  // cannot wrap since it lies in [255 * -128, 255 * 127].
+  auto HalfProduct = [&](unsigned PairBit) {
+    APInt HalfMask = SrcMask & APInt::getSplat(SrcLanes, APInt(2, PairBit));
+    KnownBits U = DAG.computeKnownBits(LHS, HalfMask, Depth + 1).zext(16);
+    KnownBits S = DAG.computeKnownBits(RHS, HalfMask, Depth + 1).sext(16);
+    return KnownBits::mul(U, S);
+  };
+
+  KnownBits EvenProd = HalfProduct(0b01);
+  KnownBits OddProd = HalfProduct(0b10);
+  Known = KnownBits::sadd_sat(EvenProd, OddProd);
+}
+
void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
KnownBits &Known,
const APInt &DemandedElts,
@@ -37294,6 +37321,16 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
break;
}
+ case X86ISD::VPMADDUBSW: {
+ SDValue LHS = Op.getOperand(0);
+ SDValue RHS = Op.getOperand(1);
+ assert(VT.getVectorElementType() == MVT::i16 &&
+ LHS.getValueType() == RHS.getValueType() &&
+ LHS.getValueType().getVectorElementType() == MVT::i8 &&
+ "Unexpected PMADDUBSW types");
+ computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
+ break;
+ }
case X86ISD::PMULUDQ: {
KnownBits Known2;
Known = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
@@ -37442,6 +37479,18 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
computeKnownBitsForPMADDWD(LHS, RHS, Known, DemandedElts, DAG, Depth);
break;
}
+ case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
+ case Intrinsic::x86_avx2_pmadd_ub_sw:
+ case Intrinsic::x86_avx512_pmaddubs_w_512: {
+ SDValue LHS = Op.getOperand(1);
+ SDValue RHS = Op.getOperand(2);
+ assert(VT.getScalarType() == MVT::i16 &&
+ LHS.getValueType() == RHS.getValueType() &&
+ LHS.getValueType().getScalarType() == MVT::i8 &&
+ "Unexpected PMADDUBSW types");
+ computeKnownBitsForPMADDUBSW(LHS, RHS, Known, DemandedElts, DAG, Depth);
+ break;
+ }
case Intrinsic::x86_sse2_psad_bw:
case Intrinsic::x86_avx2_psad_bw:
case Intrinsic::x86_avx512_psad_bw_512: {
diff --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index 8c8da55503aa1..ad3b9e645f628 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -73,21 +73,10 @@ define <8 x i16> @combine_pmaddubsw_zero_commute(<16 x i8> %a0, <16 x i8> %a1) {
}
define i32 @combine_pmaddubsw_constant() {
-; SSE-LABEL: combine_pmaddubsw_constant:
-; SSE: # %bb.0:
-; SSE-NEXT: movdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; SSE-NEXT: pextrw $3, %xmm0, %eax
-; SSE-NEXT: cwtl
-; SSE-NEXT: retq
-;
-; AVX-LABEL: combine_pmaddubsw_constant:
-; AVX: # %bb.0:
-; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; AVX-NEXT: vpextrw $3, %xmm0, %eax
-; AVX-NEXT: cwtl
-; AVX-NEXT: retq
+; CHECK-LABEL: combine_pmaddubsw_constant:
+; CHECK: # %bb.0:
+; CHECK-NEXT: movl $1694, %eax # imm = 0x69E
+; CHECK-NEXT: retq
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
%2 = extractelement <8 x i16> %1, i32 3 ; ((uint16_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
%3 = sext i16 %2 to i32
@@ -95,21 +84,10 @@ define i32 @combine_pmaddubsw_constant() {
}
define i32 @combine_pmaddubsw_constant_sat() {
-; SSE-LABEL: combine_pmaddubsw_constant_sat:
-; SSE: # %bb.0:
-; SSE-NEXT: movdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; SSE-NEXT: movd %xmm0, %eax
-; SSE-NEXT: cwtl
-; SSE-NEXT: retq
-;
-; AVX-LABEL: combine_pmaddubsw_constant_sat:
-; AVX: # %bb.0:
-; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; AVX-NEXT: vmovd %xmm0, %eax
-; AVX-NEXT: cwtl
-; AVX-NEXT: retq
+; CHECK-LABEL: combine_pmaddubsw_constant_sat:
+; CHECK: # %bb.0:
+; CHECK-NEXT: movl $-32768, %eax # imm = 0x8000
+; CHECK-NEXT: retq
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
%2 = extractelement <8 x i16> %1, i32 0 ; add_sat_i16(((uint16_t)-1*-128),((uint16_t)-1*-128)_ = add_sat_i16(255*-128),(255*-128)) = sat_i16(-65280) = -32768
%3 = sext i16 %2 to i32
More information about the llvm-commits mailing list