[llvm] 7bb251a - [InstCombine][X86] Add constant folding for PMULH/PMULHU/PMULHRS intrinsics
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 3 07:35:59 PDT 2024
Author: Simon Pilgrim
Date: 2024-07-03T15:34:28+01:00
New Revision: 7bb251a91a4f57aed458aa0572c135b5374cd2f2
URL: https://github.com/llvm/llvm-project/commit/7bb251a91a4f57aed458aa0572c135b5374cd2f2
DIFF: https://github.com/llvm/llvm-project/commit/7bb251a91a4f57aed458aa0572c135b5374cd2f2.diff
LOG: [InstCombine][X86] Add constant folding for PMULH/PMULHU/PMULHRS intrinsics
Added:
Modified:
llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
llvm/test/Transforms/InstCombine/X86/x86-pmulh.ll
llvm/test/Transforms/InstCombine/X86/x86-pmulhrs.ll
llvm/test/Transforms/InstCombine/X86/x86-pmulhu.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index 7ac149852be97..6d4734d477b3e 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -503,13 +503,15 @@ static Value *simplifyX86pack(IntrinsicInst &II,
}
static Value *simplifyX86pmulh(IntrinsicInst &II,
- InstCombiner::BuilderTy &Builder) {
+ InstCombiner::BuilderTy &Builder, bool IsSigned,
+ bool IsRounding) {
Value *Arg0 = II.getArgOperand(0);
Value *Arg1 = II.getArgOperand(1);
auto *ResTy = cast<FixedVectorType>(II.getType());
- [[maybe_unused]] auto *ArgTy = cast<FixedVectorType>(Arg0->getType());
+ auto *ArgTy = cast<FixedVectorType>(Arg0->getType());
assert(ArgTy == ResTy && ResTy->getScalarSizeInBits() == 16 &&
"Unexpected PMULH types");
+ assert((!IsRounding || IsSigned) && "PMULHRS instruction must be signed");
// Multiply by undef -> zero (NOT undef!) as other arg could still be zero.
if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1))
@@ -519,8 +521,33 @@ static Value *simplifyX86pmulh(IntrinsicInst &II,
if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1))
return ConstantAggregateZero::get(ResTy);
- // TODO: Constant folding.
- return nullptr;
+ // Constant folding.
+ if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
+ return nullptr;
+
+ // Extend to twice the width and multiply.
+ auto Cast =
+ IsSigned ? Instruction::CastOps::SExt : Instruction::CastOps::ZExt;
+ auto *ExtTy = FixedVectorType::getExtendedElementVectorType(ArgTy);
+ Value *LHS = Builder.CreateCast(Cast, Arg0, ExtTy);
+ Value *RHS = Builder.CreateCast(Cast, Arg1, ExtTy);
+ Value *Mul = Builder.CreateMul(LHS, RHS);
+
+ if (IsRounding) {
+ // PMULHRSW: truncate to vXi18 of the most significant bits, add one and
+ // extract bits[16:1].
+ auto *RndEltTy = IntegerType::get(ExtTy->getContext(), 18);
+ auto *RndTy = FixedVectorType::get(RndEltTy, ExtTy);
+ Mul = Builder.CreateLShr(Mul, 14);
+ Mul = Builder.CreateTrunc(Mul, RndTy);
+ Mul = Builder.CreateAdd(Mul, ConstantInt::get(RndTy, 1));
+ Mul = Builder.CreateLShr(Mul, 1);
+ } else {
+ // PMULH/PMULHU: extract the vXi16 most significant bits.
+ Mul = Builder.CreateLShr(Mul, 16);
+ }
+
+ return Builder.CreateTrunc(Mul, ResTy);
}
static Value *simplifyX86pmadd(IntrinsicInst &II,
@@ -2592,13 +2619,23 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
case Intrinsic::x86_sse2_pmulh_w:
case Intrinsic::x86_avx2_pmulh_w:
case Intrinsic::x86_avx512_pmulh_w_512:
+ if (Value *V = simplifyX86pmulh(II, IC.Builder, true, false)) {
+ return IC.replaceInstUsesWith(II, V);
+ }
+ break;
+
case Intrinsic::x86_sse2_pmulhu_w:
case Intrinsic::x86_avx2_pmulhu_w:
case Intrinsic::x86_avx512_pmulhu_w_512:
+ if (Value *V = simplifyX86pmulh(II, IC.Builder, false, false)) {
+ return IC.replaceInstUsesWith(II, V);
+ }
+ break;
+
case Intrinsic::x86_ssse3_pmul_hr_sw_128:
case Intrinsic::x86_avx2_pmul_hr_sw:
case Intrinsic::x86_avx512_pmul_hr_sw_512:
- if (Value *V = simplifyX86pmulh(II, IC.Builder)) {
+ if (Value *V = simplifyX86pmulh(II, IC.Builder, true, true)) {
return IC.replaceInstUsesWith(II, V);
}
break;
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-pmulh.ll b/llvm/test/Transforms/InstCombine/X86/x86-pmulh.ll
index d6a06e7d08358..53b15383aec9a 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pmulh.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pmulh.ll
@@ -111,8 +111,7 @@ define <32 x i16> @zero_pmulh_512_commute(<32 x i16> %a0) {
define <8 x i16> @fold_pmulh_128() {
; CHECK-LABEL: @fold_pmulh_128(
-; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.sse2.pmulh.w(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
-; CHECK-NEXT: ret <8 x i16> [[TMP1]]
+; CHECK-NEXT: ret <8 x i16> <i16 0, i16 0, i16 -2, i16 -2, i16 0, i16 -1, i16 -4, i16 -4>
;
%1 = call <8 x i16> @llvm.x86.sse2.pmulh.w(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
ret <8 x i16> %1
@@ -120,8 +119,7 @@ define <8 x i16> @fold_pmulh_128() {
define <16 x i16> @fold_pmulh_256() {
; CHECK-LABEL: @fold_pmulh_256(
-; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmulh.w(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
-; CHECK-NEXT: ret <16 x i16> [[TMP1]]
+; CHECK-NEXT: ret <16 x i16> <i16 0, i16 -1, i16 -1, i16 1, i16 0, i16 0, i16 -3, i16 3, i16 -1, i16 -1, i16 4, i16 5, i16 -1, i16 -1, i16 6, i16 -8>
;
%1 = call <16 x i16> @llvm.x86.avx2.pmulh.w(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
ret <16 x i16> %1
@@ -129,8 +127,7 @@ define <16 x i16> @fold_pmulh_256() {
define <32 x i16> @fold_pmulh_512() {
; CHECK-LABEL: @fold_pmulh_512(
-; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmulh.w.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
-; CHECK-NEXT: ret <32 x i16> [[TMP1]]
+; CHECK-NEXT: ret <32 x i16> <i16 0, i16 -1, i16 -1, i16 1, i16 0, i16 0, i16 -3, i16 3, i16 -1, i16 -1, i16 4, i16 5, i16 -1, i16 -1, i16 6, i16 -8, i16 0, i16 -1, i16 -1, i16 1, i16 0, i16 0, i16 -3, i16 3, i16 -1, i16 -1, i16 4, i16 5, i16 -1, i16 -1, i16 6, i16 -8>
;
%1 = call <32 x i16> @llvm.x86.avx512.pmulh.w.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
ret <32 x i16> %1
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-pmulhrs.ll b/llvm/test/Transforms/InstCombine/X86/x86-pmulhrs.ll
index 2c42534cae8b1..acc3fd0803365 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pmulhrs.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pmulhrs.ll
@@ -111,8 +111,7 @@ define <32 x i16> @zero_pmulh_512_commute(<32 x i16> %a0) {
define <8 x i16> @fold_pmulh_128() {
; CHECK-LABEL: @fold_pmulh_128(
-; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
-; CHECK-NEXT: ret <8 x i16> [[TMP1]]
+; CHECK-NEXT: ret <8 x i16> <i16 0, i16 0, i16 -3, i16 -4, i16 0, i16 0, i16 -7, i16 -8>
;
%1 = call <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
ret <8 x i16> %1
@@ -120,8 +119,7 @@ define <8 x i16> @fold_pmulh_128() {
define <16 x i16> @fold_pmulh_256() {
; CHECK-LABEL: @fold_pmulh_256(
-; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
-; CHECK-NEXT: ret <16 x i16> [[TMP1]]
+; CHECK-NEXT: ret <16 x i16> <i16 0, i16 0, i16 -2, i16 3, i16 0, i16 0, i16 -6, i16 7, i16 0, i16 0, i16 10, i16 11, i16 0, i16 0, i16 14, i16 -15>
;
%1 = call <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
ret <16 x i16> %1
@@ -129,8 +127,7 @@ define <16 x i16> @fold_pmulh_256() {
define <32 x i16> @fold_pmulh_512() {
; CHECK-LABEL: @fold_pmulh_512(
-; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmul.hr.sw.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
-; CHECK-NEXT: ret <32 x i16> [[TMP1]]
+; CHECK-NEXT: ret <32 x i16> <i16 0, i16 0, i16 -2, i16 3, i16 0, i16 0, i16 -6, i16 7, i16 0, i16 0, i16 10, i16 11, i16 0, i16 0, i16 14, i16 -15, i16 0, i16 0, i16 -2, i16 3, i16 0, i16 0, i16 -6, i16 7, i16 0, i16 0, i16 10, i16 11, i16 0, i16 0, i16 14, i16 -15>
;
%1 = call <32 x i16> @llvm.x86.avx512.pmul.hr.sw.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
ret <32 x i16> %1
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-pmulhu.ll b/llvm/test/Transforms/InstCombine/X86/x86-pmulhu.ll
index 81b890b7df6e6..52945ce82a183 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pmulhu.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pmulhu.ll
@@ -111,8 +111,7 @@ define <32 x i16> @zero_pmulhu_512_commute(<32 x i16> %a0) {
define <8 x i16> @fold_pmulhu_128() {
; CHECK-LABEL: @fold_pmulhu_128(
-; CHECK-NEXT: [[TMP1:%.*]] = call <8 x i16> @llvm.x86.sse2.pmulhu.w(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
-; CHECK-NEXT: ret <8 x i16> [[TMP1]]
+; CHECK-NEXT: ret <8 x i16> <i16 -6, i16 0, i16 1, i16 32763, i16 -14, i16 5, i16 3, i16 32757>
;
%1 = call <8 x i16> @llvm.x86.sse2.pmulhu.w(<8 x i16> <i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8>, <8 x i16> <i16 -5, i16 7, i16 -32768, i16 32765, i16 -9, i16 -11, i16 -32763, i16 32761>)
ret <8 x i16> %1
@@ -120,8 +119,7 @@ define <8 x i16> @fold_pmulhu_128() {
define <16 x i16> @fold_pmulhu_256() {
; CHECK-LABEL: @fold_pmulhu_256(
-; CHECK-NEXT: [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmulhu.w(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
-; CHECK-NEXT: ret <16 x i16> [[TMP1]]
+; CHECK-NEXT: ret <16 x i16> <i16 0, i16 6, i16 1, i16 1, i16 -13, i16 -16, i16 3, i16 3, i16 12, i16 8, i16 -32766, i16 5, i16 16, i16 12, i16 -32764, i16 32748>
;
%1 = call <16 x i16> @llvm.x86.avx2.pmulhu.w(<16 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>, <16 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>)
ret <16 x i16> %1
@@ -129,8 +127,7 @@ define <16 x i16> @fold_pmulhu_256() {
define <32 x i16> @fold_pmulhu_512() {
; CHECK-LABEL: @fold_pmulhu_512(
-; CHECK-NEXT: [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmulhu.w.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
-; CHECK-NEXT: ret <32 x i16> [[TMP1]]
+; CHECK-NEXT: ret <32 x i16> <i16 0, i16 6, i16 1, i16 1, i16 -13, i16 -16, i16 3, i16 3, i16 12, i16 8, i16 -32766, i16 5, i16 16, i16 12, i16 -32764, i16 32748, i16 0, i16 6, i16 1, i16 1, i16 -13, i16 -16, i16 3, i16 3, i16 12, i16 8, i16 -32766, i16 5, i16 16, i16 12, i16 -32764, i16 32748>
;
%1 = call <32 x i16> @llvm.x86.avx512.pmulhu.w.512(<32 x i16> <i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15, i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756>, <32 x i16> <i16 -5, i16 7, i16 -32768, i16 32766, i16 -9, i16 -11, i16 -32764, i16 32762, i16 13, i16 -15, i16 -32760, i16 32758, i16 17, i16 -19, i16 -32756, i16 32756, i16 0, i16 -1, i16 2, i16 3, i16 -4, i16 -5, i16 6, i16 7, i16 -8, i16 9, i16 -10, i16 11, i16 -12, i16 13, i16 -14, i16 -15>)
ret <32 x i16> %1
More information about the llvm-commits mailing list