[llvm] 17c567b - [X86] combineVPMADD - add constant folding support for PMADDWD/PMADDUBSW instructions
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 28 05:27:38 PDT 2024
Author: Simon Pilgrim
Date: 2024-06-28T13:26:43+01:00
New Revision: 17c567b095ab749b59d311ec9d8cd2bae584ac0b
URL: https://github.com/llvm/llvm-project/commit/17c567b095ab749b59d311ec9d8cd2bae584ac0b
DIFF: https://github.com/llvm/llvm-project/commit/17c567b095ab749b59d311ec9d8cd2bae584ac0b.diff
LOG: [X86] combineVPMADD - add constant folding support for PMADDWD/PMADDUBSW instructions
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/combine-pmadd.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 767c58270a4dc..348c2b56e6e3c 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -56922,9 +56922,13 @@ static SDValue combinePMULDQ(SDNode *N, SelectionDAG &DAG,
// Simplify VPMADDUBSW/VPMADDWD operations.
static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
- EVT VT = N->getValueType(0);
+ MVT VT = N->getSimpleValueType(0);
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
+ unsigned Opc = N->getOpcode();
+ bool IsPMADDWD = Opc == X86ISD::VPMADDWD;
+ assert((Opc == X86ISD::VPMADDWD || Opc == X86ISD::VPMADDUBSW) &&
+ "Unexpected PMADD opcode");
// Multiply by zero.
// Don't return LHS/RHS as it may contain UNDEFs.
@@ -56932,6 +56936,27 @@ static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
ISD::isBuildVectorAllZeros(RHS.getNode()))
return DAG.getConstant(0, SDLoc(N), VT);
+ // Constant folding.
+ APInt LHSUndefs, RHSUndefs;
+ SmallVector<APInt> LHSBits, RHSBits;
+ unsigned SrcEltBits = LHS.getScalarValueSizeInBits();
+ unsigned DstEltBits = VT.getScalarSizeInBits();
+ if (getTargetConstantBitsFromNode(LHS, SrcEltBits, LHSUndefs, LHSBits) &&
+ getTargetConstantBitsFromNode(RHS, SrcEltBits, RHSUndefs, RHSBits)) {
+ SmallVector<APInt> Result;
+ for (unsigned I = 0, E = LHSBits.size(); I != E; I += 2) {
+ APInt LHSLo = LHSBits[I + 0], LHSHi = LHSBits[I + 1];
+ APInt RHSLo = RHSBits[I + 0], RHSHi = RHSBits[I + 1];
+ LHSLo = IsPMADDWD ? LHSLo.sext(DstEltBits) : LHSLo.zext(DstEltBits);
+ LHSHi = IsPMADDWD ? LHSHi.sext(DstEltBits) : LHSHi.zext(DstEltBits);
+ APInt Lo = LHSLo * RHSLo.sext(DstEltBits);
+ APInt Hi = LHSHi * RHSHi.sext(DstEltBits);
+ APInt Res = IsPMADDWD ? (Lo + Hi) : Lo.sadd_sat(Hi);
+ Result.push_back(Res);
+ }
+ return getConstVector(Result, VT, DAG, SDLoc(N));
+ }
+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedElts = APInt::getAllOnes(VT.getVectorNumElements());
if (TLI.SimplifyDemandedVectorElts(SDValue(N, 0), DemandedElts, DCI))
diff --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index c4c3e2870145e..c5f9c8506f291 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -88,43 +88,33 @@ define <4 x i32> @combine_pmaddwd_demandedelts(<8 x i16> %a0, <8 x i16> %a1) {
ret <4 x i32> %4
}
-; TODO: [2] = (-5*13)+(6*-15) = -155 = 4294967141
+; [2]: (-5*13)+(6*-15) = -155 = 4294967141
define <4 x i32> @combine_pmaddwd_constant() {
; SSE-LABEL: combine_pmaddwd_constant:
; SSE: # %bb.0:
-; SSE-NEXT: pmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
-; SSE-NEXT: pmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
+; SSE-NEXT: movaps {{.*#+}} xmm0 = [19,17,4294967141,271]
; SSE-NEXT: retq
;
; AVX-LABEL: combine_pmaddwd_constant:
; AVX: # %bb.0:
-; AVX-NEXT: vpmovsxbw {{.*#+}} xmm0 = [65535,2,3,65532,65531,6,7,65528]
-; AVX-NEXT: vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [65531,7,65527,65525,13,65521,17,65517]
+; AVX-NEXT: vmovaps {{.*#+}} xmm0 = [19,17,4294967141,271]
; AVX-NEXT: retq
%1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -9, i16 -11, i16 13, i16 -15, i16 17, i16 -19>)
ret <4 x i32> %1
}
; ensure we don't assume pmaddwd performs add nsw
-; TODO: (-32768*-32768)+(-32768*-32768) = 0x80000000 = 2147483648
+; [0]: (-32768*-32768)+(-32768*-32768) = 0x80000000 = 2147483648
define <4 x i32> @combine_pmaddwd_constant_nsw() {
; SSE-LABEL: combine_pmaddwd_constant_nsw:
; SSE: # %bb.0:
-; SSE-NEXT: movdqa {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
-; SSE-NEXT: pmaddwd %xmm0, %xmm0
+; SSE-NEXT: movaps {{.*#+}} xmm0 = [2147483648,2147483648,2147483648,2147483648]
; SSE-NEXT: retq
;
-; AVX1-LABEL: combine_pmaddwd_constant_nsw:
-; AVX1: # %bb.0:
-; AVX1-NEXT: vbroadcastss {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
-; AVX1-NEXT: vpmaddwd %xmm0, %xmm0, %xmm0
-; AVX1-NEXT: retq
-;
-; AVX2-LABEL: combine_pmaddwd_constant_nsw:
-; AVX2: # %bb.0:
-; AVX2-NEXT: vpbroadcastw {{.*#+}} xmm0 = [32768,32768,32768,32768,32768,32768,32768,32768]
-; AVX2-NEXT: vpmaddwd %xmm0, %xmm0, %xmm0
-; AVX2-NEXT: retq
+; AVX-LABEL: combine_pmaddwd_constant_nsw:
+; AVX: # %bb.0:
+; AVX-NEXT: vbroadcastss {{.*#+}} xmm0 = [2147483648,2147483648,2147483648,2147483648]
+; AVX-NEXT: retq
%1 = insertelement <8 x i16> undef, i16 32768, i32 0
%2 = shufflevector <8 x i16> %1, <8 x i16> undef, <8 x i32> zeroinitializer
%3 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %2, <8 x i16> %2)
@@ -213,51 +203,26 @@ define <8 x i16> @combine_pmaddubsw_demandedelts(<16 x i8> %a0, <16 x i8> %a1) {
ret <8 x i16> %4
}
-; TODO
+; [3]: ((uint8_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
define i32 @combine_pmaddubsw_constant() {
-; SSE-LABEL: combine_pmaddubsw_constant:
-; SSE: # %bb.0:
-; SSE-NEXT: movdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; SSE-NEXT: pextrw $3, %xmm0, %eax
-; SSE-NEXT: cwtl
-; SSE-NEXT: retq
-;
-; AVX-LABEL: combine_pmaddubsw_constant:
-; AVX: # %bb.0:
-; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [0,1,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [1,2,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; AVX-NEXT: vpextrw $3, %xmm0, %eax
-; AVX-NEXT: cwtl
-; AVX-NEXT: retq
+; CHECK-LABEL: combine_pmaddubsw_constant:
+; CHECK: # %bb.0:
+; CHECK-NEXT: movl $1694, %eax # imm = 0x69E
+; CHECK-NEXT: retq
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
- %2 = extractelement <8 x i16> %1, i32 3 ; ((uint16_t)-6*7)+(7*-8) = (250*7)+(7*-8) = 1694
+ %2 = extractelement <8 x i16> %1, i32 3
%3 = sext i16 %2 to i32
ret i32 %3
}
-; TODO
+; [0]: add_sat_i16(((uint8_t)-1*-128),((uint8_t)-1*-128)) = add_sat_i16((255*-128),(255*-128)) = sat_i16(-65280) = -32768
define i32 @combine_pmaddubsw_constant_sat() {
-; SSE-LABEL: combine_pmaddubsw_constant_sat:
-; SSE: # %bb.0:
-; SSE-NEXT: movdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; SSE-NEXT: pmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; SSE-NEXT: movd %xmm0, %eax
-; SSE-NEXT: cwtl
-; SSE-NEXT: retq
-;
-; AVX-LABEL: combine_pmaddubsw_constant_sat:
-; AVX: # %bb.0:
-; AVX-NEXT: vmovdqa {{.*#+}} xmm0 = [255,255,2,3,4,5,250,7,8,9,10,11,12,13,14,15]
-; AVX-NEXT: vpmaddubsw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0 # [128,128,3,4,5,6,7,248,9,10,11,12,13,14,15,16]
-; AVX-NEXT: vmovd %xmm0, %eax
-; AVX-NEXT: cwtl
-; AVX-NEXT: retq
+; CHECK-LABEL: combine_pmaddubsw_constant_sat:
+; CHECK: # %bb.0:
+; CHECK-NEXT: movl $-32768, %eax # imm = 0x8000
+; CHECK-NEXT: retq
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
- %2 = extractelement <8 x i16> %1, i32 0 ; add_sat_i16(((uint16_t)-1*-128),((uint16_t)-1*-128)_ = add_sat_i16(255*-128),(255*-128)) = sat_i16(-65280) = -32768
+ %2 = extractelement <8 x i16> %1, i32 0
%3 = sext i16 %2 to i32
ret i32 %3
}
-
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; CHECK: {{.*}}
More information about the llvm-commits
mailing list