[llvm] 9de14e2 - [InstCombine][X86] Add zero arg handling for PMADDWD/PMADDUBSW intrinsics

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 28 03:16:55 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-28T11:16:32+01:00
New Revision: 9de14e24443046f5df39dad864af0bcdd85b53e0

URL: https://github.com/llvm/llvm-project/commit/9de14e24443046f5df39dad864af0bcdd85b53e0
DIFF: https://github.com/llvm/llvm-project/commit/9de14e24443046f5df39dad864af0bcdd85b53e0.diff

LOG: [InstCombine][X86] Add zero arg handling for PMADDWD/PMADDUBSW intrinsics

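Fold PMADDWD/PMADDUBSW intrinsic calls to zeroinitializer when either operand is an all-zeros vector (multiply by zero).

This adds a shared simplifyX86pmadd helper as the initial setup for future PMADDWD/PMADDUBSW simplifications and constant folding.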

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
    llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
    llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index fe7213fc8721e..48e24a123417e 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -502,6 +502,24 @@ static Value *simplifyX86pack(IntrinsicInst &II,
   return Builder.CreateTrunc(Shuffle, ResTy);
 }
 
+static Value *simplifyX86pmadd(IntrinsicInst &II,
+                               InstCombiner::BuilderTy &Builder) {
+  Value *Arg0 = II.getArgOperand(0);
+  Value *Arg1 = II.getArgOperand(1);
+  auto *ResTy = cast<FixedVectorType>(II.getType());
+  [[maybe_unused]] auto *ArgTy = cast<FixedVectorType>(Arg0->getType());
+
+  assert(ArgTy->getNumElements() == (2 * ResTy->getNumElements()) &&
+         ResTy->getScalarSizeInBits() == (2 * ArgTy->getScalarSizeInBits()) &&
+         "Unexpected PMADD types");
+
+  // Multiply by zero.
+  if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1))
+    return ConstantAggregateZero::get(ResTy);
+
+  return nullptr;
+}
+
 static Value *simplifyX86movmsk(const IntrinsicInst &II,
                                 InstCombiner::BuilderTy &Builder) {
   Value *Arg = II.getArgOperand(0);
@@ -2478,6 +2496,22 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
     }
     break;
 
+  case Intrinsic::x86_sse2_pmadd_wd:
+  case Intrinsic::x86_avx2_pmadd_wd:
+  case Intrinsic::x86_avx512_pmaddw_d_512:
+    if (Value *V = simplifyX86pmadd(II, IC.Builder)) {
+      return IC.replaceInstUsesWith(II, V);
+    }
+    break;
+
+  case Intrinsic::x86_ssse3_pmadd_ub_sw_128:
+  case Intrinsic::x86_avx2_pmadd_ub_sw:
+  case Intrinsic::x86_avx512_pmaddubs_w_512:
+    if (Value *V = simplifyX86pmadd(II, IC.Builder)) {
+      return IC.replaceInstUsesWith(II, V);
+    }
+    break;
+
   case Intrinsic::x86_pclmulqdq:
   case Intrinsic::x86_pclmulqdq_256:
   case Intrinsic::x86_pclmulqdq_512: {

diff  --git a/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll b/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
index 6724b8af8fca9..4f5b58f79a29c 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
@@ -38,8 +38,7 @@ define <32 x i16> @undef_pmaddubsw_512() {
 
 define <8 x i16> @zero_pmaddubsw_128(<16 x i8> %a0) {
 ; CHECK-LABEL: @zero_pmaddubsw_128(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> [[A0:%.*]], <16 x i8> zeroinitializer)
-; CHECK-NEXT:    ret <8 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <8 x i16> zeroinitializer
 ;
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a0, <16 x i8> zeroinitializer)
   ret <8 x i16> %1
@@ -47,8 +46,7 @@ define <8 x i16> @zero_pmaddubsw_128(<16 x i8> %a0) {
 
 define <8 x i16> @zero_pmaddubsw_128_commute(<16 x i8> %a0) {
 ; CHECK-LABEL: @zero_pmaddubsw_128_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> zeroinitializer, <16 x i8> [[A0:%.*]])
-; CHECK-NEXT:    ret <8 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <8 x i16> zeroinitializer
 ;
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> zeroinitializer, <16 x i8> %a0)
   ret <8 x i16> %1
@@ -56,8 +54,7 @@ define <8 x i16> @zero_pmaddubsw_128_commute(<16 x i8> %a0) {
 
 define <16 x i16> @zero_pmaddubsw_256(<32 x i8>%a0) {
 ; CHECK-LABEL: @zero_pmaddubsw_256(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> [[A0:%.*]], <32 x i8> zeroinitializer)
-; CHECK-NEXT:    ret <16 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <16 x i16> zeroinitializer
 ;
   %1 = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> %a0, <32 x i8> zeroinitializer)
   ret <16 x i16> %1
@@ -65,8 +62,7 @@ define <16 x i16> @zero_pmaddubsw_256(<32 x i8>%a0) {
 
 define <16 x i16> @zero_pmaddubsw_256_commute(<32 x i8> %a0) {
 ; CHECK-LABEL: @zero_pmaddubsw_256_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> zeroinitializer, <32 x i8> [[A0:%.*]])
-; CHECK-NEXT:    ret <16 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <16 x i16> zeroinitializer
 ;
   %1 = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> zeroinitializer, <32 x i8> %a0)
   ret <16 x i16> %1
@@ -74,17 +70,15 @@ define <16 x i16> @zero_pmaddubsw_256_commute(<32 x i8> %a0) {
 
 define <32 x i16> @zero_pmaddubsw_512(<64 x i8> %a0) {
 ; CHECK-LABEL: @zero_pmaddubsw_512(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> [[A0:%.*]], <64 x i8> zeroinitializer)
-; CHECK-NEXT:    ret <32 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <32 x i16> zeroinitializer
 ;
   %1 = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> %a0, <64 x i8> zeroinitializer)
   ret <32 x i16> %1
 }
 
-define <32 x i16> @zero_pmaddubsw_512_commuite(<64 x i8> %a0) {
-; CHECK-LABEL: @zero_pmaddubsw_512_commuite(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> zeroinitializer, <64 x i8> [[A0:%.*]])
-; CHECK-NEXT:    ret <32 x i16> [[TMP1]]
+define <32 x i16> @zero_pmaddubsw_512_commute(<64 x i8> %a0) {
+; CHECK-LABEL: @zero_pmaddubsw_512_commute(
+; CHECK-NEXT:    ret <32 x i16> zeroinitializer
 ;
   %1 = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> zeroinitializer, <64 x i8> %a0)
   ret <32 x i16> %1

diff  --git a/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll b/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
index ebcfbed598325..b91670e906be5 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
@@ -38,8 +38,7 @@ define <16 x i32> @undef_pmaddwd_512() {
 
 define <4 x i32> @zero_pmaddwd_128(<8 x i16> %a0) {
 ; CHECK-LABEL: @zero_pmaddwd_128(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> [[A0:%.*]], <8 x i16> zeroinitializer)
-; CHECK-NEXT:    ret <4 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <4 x i32> zeroinitializer
 ;
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a0, <8 x i16> zeroinitializer)
   ret <4 x i32> %1
@@ -47,8 +46,7 @@ define <4 x i32> @zero_pmaddwd_128(<8 x i16> %a0) {
 
 define <4 x i32> @zero_pmaddwd_128_commute(<8 x i16> %a0) {
 ; CHECK-LABEL: @zero_pmaddwd_128_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> zeroinitializer, <8 x i16> [[A0:%.*]])
-; CHECK-NEXT:    ret <4 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <4 x i32> zeroinitializer
 ;
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> zeroinitializer, <8 x i16> %a0)
   ret <4 x i32> %1
@@ -56,8 +54,7 @@ define <4 x i32> @zero_pmaddwd_128_commute(<8 x i16> %a0) {
 
 define <8 x i32> @zero_pmaddwd_256(<16 x i16> %a0) {
 ; CHECK-LABEL: @zero_pmaddwd_256(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> [[A0:%.*]], <16 x i16> zeroinitializer)
-; CHECK-NEXT:    ret <8 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <8 x i32> zeroinitializer
 ;
   %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a0, <16 x i16> zeroinitializer)
   ret <8 x i32> %1
@@ -65,8 +62,7 @@ define <8 x i32> @zero_pmaddwd_256(<16 x i16> %a0) {
 
 define <8 x i32> @zero_pmaddwd_256_commute(<16 x i16> %a0) {
 ; CHECK-LABEL: @zero_pmaddwd_256_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> zeroinitializer, <16 x i16> [[A0:%.*]])
-; CHECK-NEXT:    ret <8 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <8 x i32> zeroinitializer
 ;
   %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> zeroinitializer, <16 x i16> %a0)
   ret <8 x i32> %1
@@ -74,17 +70,15 @@ define <8 x i32> @zero_pmaddwd_256_commute(<16 x i16> %a0) {
 
 define <16 x i32> @zero_pmaddwd_512(<32 x i16> %a0) {
 ; CHECK-LABEL: @zero_pmaddwd_512(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> [[A0:%.*]], <32 x i16> zeroinitializer)
-; CHECK-NEXT:    ret <16 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <16 x i32> zeroinitializer
 ;
   %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a0, <32 x i16> zeroinitializer)
   ret <16 x i32> %1
 }
 
-define <16 x i32> @zero_pmaddwd_512_commuite(<32 x i16> %a0) {
-; CHECK-LABEL: @zero_pmaddwd_512_commuite(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> zeroinitializer, <32 x i16> [[A0:%.*]])
-; CHECK-NEXT:    ret <16 x i32> [[TMP1]]
+define <16 x i32> @zero_pmaddwd_512_commute(<32 x i16> %a0) {
+; CHECK-LABEL: @zero_pmaddwd_512_commute(
+; CHECK-NEXT:    ret <16 x i32> zeroinitializer
 ;
   %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> zeroinitializer, <32 x i16> %a0)
   ret <16 x i32> %1


        


More information about the llvm-commits mailing list