[llvm] f1faba2 - [InstCombine][X86] Add constant folding for PMADDWD/PMADDUBSW intrinsics
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 28 03:49:50 PDT 2024
Author: Simon Pilgrim
Date: 2024-06-28T11:49:08+01:00
New Revision: f1faba25433c971f024dd8a29da14020246e89ec
URL: https://github.com/llvm/llvm-project/commit/f1faba25433c971f024dd8a29da14020246e89ec
DIFF: https://github.com/llvm/llvm-project/commit/f1faba25433c971f024dd8a29da14020246e89ec.diff
LOG: [InstCombine][X86] Add constant folding for PMADDWD/PMADDUBSW intrinsics
Added:
Modified:
llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index 48e24a123417e..c4b3ee5cec9fd 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -503,13 +503,15 @@ static Value *simplifyX86pack(IntrinsicInst &II,
}
static Value *simplifyX86pmadd(IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
+ InstCombiner::BuilderTy &Builder,
+ bool IsPMADDWD) {
Value *Arg0 = II.getArgOperand(0);
Value *Arg1 = II.getArgOperand(1);
auto *ResTy = cast<FixedVectorType>(II.getType());
[[maybe_unused]] auto *ArgTy = cast<FixedVectorType>(Arg0->getType());
- assert(ArgTy->getNumElements() == (2 * ResTy->getNumElements()) &&
+ unsigned NumDstElts = ResTy->getNumElements();
+ assert(ArgTy->getNumElements() == (2 * NumDstElts) &&
ResTy->getScalarSizeInBits() == (2 * ArgTy->getScalarSizeInBits()) &&
"Unexpected PMADD types");
@@ -517,7 +519,37 @@ static Value *simplifyX86pmadd(IntrinsicInst &II,
if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1))
return ConstantAggregateZero::get(ResTy);
- return nullptr;
+ // Constant folding.
+ if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
+ return nullptr;
+
+ // Split Lo/Hi element pairs, extend, multiply and add together.
+ // PMADDWD(X,Y) =
+ // add(mul(sext(lhs[0]),sext(rhs[0])),mul(sext(lhs[1]),sext(rhs[1])))
+ // PMADDUBSW(X,Y) =
+ // sadd_sat(mul(zext(lhs[0]),sext(rhs[0])),mul(zext(lhs[1]),sext(rhs[1])))
+ SmallVector<int> LoMask, HiMask;
+ for (unsigned I = 0; I != NumDstElts; ++I) {
+ LoMask.push_back(2 * I + 0);
+ HiMask.push_back(2 * I + 1);
+ }
+
+ auto *LHSLo = Builder.CreateShuffleVector(Arg0, LoMask);
+ auto *LHSHi = Builder.CreateShuffleVector(Arg0, HiMask);
+ auto *RHSLo = Builder.CreateShuffleVector(Arg1, LoMask);
+ auto *RHSHi = Builder.CreateShuffleVector(Arg1, HiMask);
+
+ auto LHSCast =
+ IsPMADDWD ? Instruction::CastOps::SExt : Instruction::CastOps::ZExt;
+ LHSLo = Builder.CreateCast(LHSCast, LHSLo, ResTy);
+ LHSHi = Builder.CreateCast(LHSCast, LHSHi, ResTy);
+ RHSLo = Builder.CreateCast(Instruction::CastOps::SExt, RHSLo, ResTy);
+ RHSHi = Builder.CreateCast(Instruction::CastOps::SExt, RHSHi, ResTy);
+ Value *Lo = Builder.CreateMul(LHSLo, RHSLo);
+ Value *Hi = Builder.CreateMul(LHSHi, RHSHi);
+ return IsPMADDWD
+ ? Builder.CreateAdd(Lo, Hi)
+ : Builder.CreateIntrinsic(ResTy, Intrinsic::sadd_sat, {Lo, Hi});
}
static Value *simplifyX86movmsk(const IntrinsicInst &II,
@@ -2499,7 +2531,7 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
case Intrinsic::x86_sse2_pmadd_wd:
case Intrinsic::x86_avx2_pmadd_wd:
case Intrinsic::x86_avx512_pmaddw_d_512:
- if (Value *V = simplifyX86pmadd(II, IC.Builder)) {
+ if (Value *V = simplifyX86pmadd(II, IC.Builder, true)) {
return IC.replaceInstUsesWith(II, V);
}
break;
@@ -2507,7 +2539,7 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
case Intrinsic::x86_avx2_pmadd_ub_sw:
case Intrinsic::x86_avx512_pmaddubs_w_512:
- if (Value *V = simplifyX86pmadd(II, IC.Builder)) {
+ if (Value *V = simplifyX86pmadd(II, IC.Builder, false)) {
return IC.replaceInstUsesWith(II, V);
}
break;
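For reference, the per-element semantics that this expansion constant folds can be sketched in scalar C++ as follows (a minimal sketch, not part of the patch; the pmaddwd_lane/pmaddubsw_lane helper names are hypothetical):

  #include <algorithm>
  #include <cstdint>

  // PMADDWD: sign-extend both i16 inputs of each pair, multiply, and add the
  // two products with 32-bit wrapping semantics.
  constexpr int32_t pmaddwd_lane(int16_t A0, int16_t A1, int16_t B0, int16_t B1) {
    // Widen to 64 bits so the single overflowing corner case (both products
    // equal to 0x40000000) stays well-defined, then wrap modulo 2^32.
    int64_t Sum = int64_t(A0) * B0 + int64_t(A1) * B1;
    return int32_t(uint32_t(Sum));
  }

  // PMADDUBSW: zero-extend the LHS i8, sign-extend the RHS i8, multiply, and
  // add the two products with signed 16-bit saturation (sadd_sat).
  constexpr int16_t pmaddubsw_lane(uint8_t A0, uint8_t A1, int8_t B0, int8_t B1) {
    int32_t Sum = int32_t(A0) * B0 + int32_t(A1) * B1;
    return int16_t(std::clamp(Sum, int32_t(INT16_MIN), int32_t(INT16_MAX)));
  }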
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll b/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
index 967664ca5a61c..0a53032920d12 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
@@ -117,8 +117,7 @@ define <32 x i16> @zero_pmaddubsw_512_commute(<64 x i8> %a0) {
define <8 x i16> @fold_pmaddubsw_128() {
; CHECK-LABEL: @fold_pmaddubsw_128(
-; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
-; CHECK-NEXT: ret <8 x i16> [[TMP1]]
+; CHECK-NEXT: ret <8 x i16> <i16 -32768, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450>
;
%1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <16 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
ret <8 x i16> %1
@@ -126,8 +125,7 @@ define <8 x i16> @fold_pmaddubsw_128() {
define <16 x i16> @fold_pmaddubsw_256() {
; CHECK-LABEL: @fold_pmaddubsw_256(
-; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>, <32 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>)
-; CHECK-NEXT: ret <16 x i16> [[TMP1]]
+; CHECK-NEXT: ret <16 x i16> <i16 -32768, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450, i16 -256, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450>
;
%1 = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>, <32 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>)
ret <16 x i16> %1
@@ -135,8 +133,7 @@ define <16 x i16> @fold_pmaddubsw_256() {
define <32 x i16> @fold_pmaddubsw_512() {
; CHECK-LABEL: @fold_pmaddubsw_512(
-; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <64 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
-; CHECK-NEXT: ret <32 x i16> [[TMP1]]
+; CHECK-NEXT: ret <32 x i16> <i16 -32768, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450, i16 -256, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450, i16 -256, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450, i16 -32768, i16 18, i16 50, i16 1694, i16 162, i16 242, i16 338, i16 450>
;
%1 = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> <i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15>, <64 x i8> <i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -1, i8 -1, i8 2, i8 3, i8 4, i8 5, i8 -6, i8 7, i8 8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 -128, i8 -128, i8 3, i8 4, i8 5, i8 6, i8 7, i8 -8, i8 9, i8 10, i8 11, i8 12, i8 13, i8 14, i8 15, i8 16>)
ret <32 x i16> %1
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll b/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
index ad7b20ba6141b..ccf0d6282ddb2 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
@@ -117,8 +117,7 @@ define <16 x i32> @zero_pmaddwd_512_commute(<32 x i16> %a0) {
define <4 x i32> @fold_pmaddwd_128() {
; CHECK-LABEL: @fold_pmaddwd_128(
-; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
-; CHECK-NEXT: ret <4 x i32> [[TMP1]]
+; CHECK-NEXT: ret <4 x i32> <i32 19, i32 -229364, i32 -21, i32 -491429>
;
%1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
ret <4 x i32> %1
@@ -126,8 +125,7 @@ define <4 x i32> @fold_pmaddwd_128() {
define <8 x i32> @fold_pmaddwd_256() {
; CHECK-LABEL: @fold_pmaddwd_256(
-; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
-; CHECK-NEXT: ret <8 x i32> [[TMP1]]
+; CHECK-NEXT: ret <8 x i32> <i32 -7, i32 32762, i32 91, i32 32750, i32 -239, i32 687938, i32 -451, i32 -32756>
;
%1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
ret <8 x i32> %1
@@ -135,8 +133,7 @@ define <8 x i32> @fold_pmaddwd_256() {
define <16 x i32> @fold_pmaddwd_512() {
; CHECK-LABEL: @fold_pmaddwd_512(
-; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
-; CHECK-NEXT: ret <16 x i32> [[TMP1]]
+; CHECK-NEXT: ret <16 x i32> <i32 -7, i32 32762, i32 91, i32 32750, i32 -239, i32 687938, i32 -451, i32 -32756, i32 -7, i32 32762, i32 91, i32 32750, i32 -239, i32 687938, i32 -451, i32 -32756>
;
%1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
ret <16 x i32> %1
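As a spot check against the updated tests above (using the hypothetical helpers sketched earlier): pmaddwd_lane(-1, 2, -5, 7) evaluates to -1*-5 + 2*7 = 19, the first lane of fold_pmaddwd_128, and pmaddubsw_lane(0xFF, 0xFF, -128, -128) evaluates to 255*-128 + 255*-128 = -65280, which saturates to -32768, the first lane of fold_pmaddubsw_128.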