[llvm] d893ed7 - [InstCombine][X86] Add undef arg handling for PMADDWD/PMADDUBSW intrinsics

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 28 04:12:07 PDT 2024


Author: Simon Pilgrim
Date: 2024-06-28T12:10:39+01:00
New Revision: d893ed78718e25a982dcba9cdba2d78212b79353

URL: https://github.com/llvm/llvm-project/commit/d893ed78718e25a982dcba9cdba2d78212b79353
DIFF: https://github.com/llvm/llvm-project/commit/d893ed78718e25a982dcba9cdba2d78212b79353.diff

LOG: [InstCombine][X86] Add undef arg handling for PMADDWD/PMADDUBSW intrinsics

When either argument is undef, these intrinsics fold to zero rather than undef, because the other argument could still be zero (and zero times anything is zero).

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
    llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
    llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
index c4b3ee5cec9fd..49f16833ee513 100644
--- a/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp
@@ -515,6 +515,10 @@ static Value *simplifyX86pmadd(IntrinsicInst &II,
          ResTy->getScalarSizeInBits() == (2 * ArgTy->getScalarSizeInBits()) &&
          "Unexpected PMADD types");
 
+  // Multiply by undef -> zero (NOT undef!) as other arg could still be zero.
+  if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1))
+    return ConstantAggregateZero::get(ResTy);
+
   // Multiply by zero.
   if (isa<ConstantAggregateZero>(Arg0) || isa<ConstantAggregateZero>(Arg1))
     return ConstantAggregateZero::get(ResTy);

diff  --git a/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll b/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
index 0a53032920d12..6de8127d5e9be 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pmaddubsw.ll
@@ -7,8 +7,7 @@
 
 define <8 x i16> @undef_pmaddubsw_128(<16 x i8> %a0) {
 ; CHECK-LABEL: @undef_pmaddubsw_128(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> [[A0:%.*]], <16 x i8> undef)
-; CHECK-NEXT:    ret <8 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <8 x i16> zeroinitializer
 ;
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a0, <16 x i8> undef)
   ret <8 x i16> %1
@@ -16,8 +15,7 @@ define <8 x i16> @undef_pmaddubsw_128(<16 x i8> %a0) {
 
 define <8 x i16> @undef_pmaddubsw_128_commute(<16 x i8> %a0) {
 ; CHECK-LABEL: @undef_pmaddubsw_128_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> undef, <16 x i8> [[A0:%.*]])
-; CHECK-NEXT:    ret <8 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <8 x i16> zeroinitializer
 ;
   %1 = call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> undef, <16 x i8> %a0)
   ret <8 x i16> %1
@@ -25,8 +23,7 @@ define <8 x i16> @undef_pmaddubsw_128_commute(<16 x i8> %a0) {
 
 define <16 x i16> @undef_pmaddubsw_256(<32 x i8> %a0) {
 ; CHECK-LABEL: @undef_pmaddubsw_256(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> [[A0:%.*]], <32 x i8> undef)
-; CHECK-NEXT:    ret <16 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <16 x i16> zeroinitializer
 ;
   %1 = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> %a0, <32 x i8> undef)
   ret <16 x i16> %1
@@ -34,8 +31,7 @@ define <16 x i16> @undef_pmaddubsw_256(<32 x i8> %a0) {
 
 define <16 x i16> @undef_pmaddubsw_256_commute(<32 x i8> %a0) {
 ; CHECK-LABEL: @undef_pmaddubsw_256_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> undef, <32 x i8> [[A0:%.*]])
-; CHECK-NEXT:    ret <16 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <16 x i16> zeroinitializer
 ;
   %1 = call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> undef, <32 x i8> %a0)
   ret <16 x i16> %1
@@ -43,8 +39,7 @@ define <16 x i16> @undef_pmaddubsw_256_commute(<32 x i8> %a0) {
 
 define <32 x i16> @undef_pmaddubsw_512(<64 x i8> %a0) {
 ; CHECK-LABEL: @undef_pmaddubsw_512(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> [[A0:%.*]], <64 x i8> undef)
-; CHECK-NEXT:    ret <32 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <32 x i16> zeroinitializer
 ;
   %1 = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> %a0, <64 x i8> undef)
   ret <32 x i16> %1
@@ -52,8 +47,7 @@ define <32 x i16> @undef_pmaddubsw_512(<64 x i8> %a0) {
 
 define <32 x i16> @undef_pmaddubsw_512_commute(<64 x i8> %a0) {
 ; CHECK-LABEL: @undef_pmaddubsw_512_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> undef, <64 x i8> [[A0:%.*]])
-; CHECK-NEXT:    ret <32 x i16> [[TMP1]]
+; CHECK-NEXT:    ret <32 x i16> zeroinitializer
 ;
   %1 = call <32 x i16> @llvm.x86.avx512.pmaddubs.w.512(<64 x i8> undef, <64 x i8> %a0)
   ret <32 x i16> %1

diff  --git a/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll b/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
index ccf0d6282ddb2..25849f70b6bbf 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-pmaddwd.ll
@@ -7,8 +7,7 @@
 
 define <4 x i32> @undef_pmaddwd_128(<8 x i16> %a0) {
 ; CHECK-LABEL: @undef_pmaddwd_128(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> [[A0:%.*]], <8 x i16> undef)
-; CHECK-NEXT:    ret <4 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <4 x i32> zeroinitializer
 ;
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> %a0, <8 x i16> undef)
   ret <4 x i32> %1
@@ -16,8 +15,7 @@ define <4 x i32> @undef_pmaddwd_128(<8 x i16> %a0) {
 
 define <4 x i32> @undef_pmaddwd_128_commute(<8 x i16> %a0) {
 ; CHECK-LABEL: @undef_pmaddwd_128_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> undef, <8 x i16> [[A0:%.*]])
-; CHECK-NEXT:    ret <4 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <4 x i32> zeroinitializer
 ;
   %1 = call <4 x i32> @llvm.x86.sse2.pmadd.wd(<8 x i16> undef, <8 x i16> %a0)
   ret <4 x i32> %1
@@ -25,8 +23,7 @@ define <4 x i32> @undef_pmaddwd_128_commute(<8 x i16> %a0) {
 
 define <8 x i32> @undef_pmaddwd_256(<16 x i16> %a0) {
 ; CHECK-LABEL: @undef_pmaddwd_256(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> [[A0:%.*]], <16 x i16> undef)
-; CHECK-NEXT:    ret <8 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <8 x i32> zeroinitializer
 ;
   %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a0, <16 x i16> undef)
   ret <8 x i32> %1
@@ -34,8 +31,7 @@ define <8 x i32> @undef_pmaddwd_256(<16 x i16> %a0) {
 
 define <8 x i32> @undef_pmaddwd_256_commute(<16 x i16> %a0) {
 ; CHECK-LABEL: @undef_pmaddwd_256_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> undef, <16 x i16> [[A0:%.*]])
-; CHECK-NEXT:    ret <8 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <8 x i32> zeroinitializer
 ;
   %1 = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> undef, <16 x i16> %a0)
   ret <8 x i32> %1
@@ -43,8 +39,7 @@ define <8 x i32> @undef_pmaddwd_256_commute(<16 x i16> %a0) {
 
 define <16 x i32> @undef_pmaddwd_512(<32 x i16> %a0) {
 ; CHECK-LABEL: @undef_pmaddwd_512(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> [[A0:%.*]], <32 x i16> undef)
-; CHECK-NEXT:    ret <16 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <16 x i32> zeroinitializer
 ;
   %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> %a0, <32 x i16> undef)
   ret <16 x i32> %1
@@ -52,8 +47,7 @@ define <16 x i32> @undef_pmaddwd_512(<32 x i16> %a0) {
 
 define <16 x i32> @undef_pmaddwd_512_commute(<32 x i16> %a0) {
 ; CHECK-LABEL: @undef_pmaddwd_512_commute(
-; CHECK-NEXT:    [[TMP1:%.*]] = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> undef, <32 x i16> [[A0:%.*]])
-; CHECK-NEXT:    ret <16 x i32> [[TMP1]]
+; CHECK-NEXT:    ret <16 x i32> zeroinitializer
 ;
   %1 = call <16 x i32> @llvm.x86.avx512.pmaddw.d.512(<32 x i16> undef, <32 x i16> %a0)
   ret <16 x i32> %1


        


More information about the llvm-commits mailing list