[llvm] c7e0f1e - [X86] Allow input vector extracted from larger vector when combining to VPMADDUBSW (#89584)

via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 22 06:04:10 PDT 2024


Author: Phoebe Wang
Date: 2024-04-22T21:04:05+08:00
New Revision: c7e0f1e988d73555d1da7474528996e748622f42

URL: https://github.com/llvm/llvm-project/commit/c7e0f1e988d73555d1da7474528996e748622f42
DIFF: https://github.com/llvm/llvm-project/commit/c7e0f1e988d73555d1da7474528996e748622f42.diff

LOG: [X86] Allow input vector extracted from larger vector when combining to VPMADDUBSW (#89584)

The VPMADDUBSW combine failed on main trunk for this case; reproducer: https://godbolt.org/z/edWMz8chE

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/pmaddubsw.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 3a51c7c2ca854e..dd40d079c7e2f7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -51841,6 +51841,17 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
       return SDValue();
   }
 
+  auto ExtractVec = [&DAG, &DL, NumElems](SDValue &Ext) {
+    EVT ExtVT = Ext.getValueType();
+    if (ExtVT.getVectorNumElements() != NumElems * 2) {
+      MVT NVT = MVT::getVectorVT(MVT::i8, NumElems * 2);
+      Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, Ext,
+                        DAG.getIntPtrConstant(0, DL));
+    }
+  };
+  ExtractVec(ZExtIn);
+  ExtractVec(SExtIn);
+
   auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
                          ArrayRef<SDValue> Ops) {
     // Shrink by adding truncate nodes and let DAGCombine fold with the

diff  --git a/llvm/test/CodeGen/X86/pmaddubsw.ll b/llvm/test/CodeGen/X86/pmaddubsw.ll
index e46a14673a5171..d6c9877cd99b63 100644
--- a/llvm/test/CodeGen/X86/pmaddubsw.ll
+++ b/llvm/test/CodeGen/X86/pmaddubsw.ll
@@ -469,3 +469,41 @@ define <8 x i16> @pmaddubsw_bad_indices(ptr %Aptr, ptr %Bptr) {
   %trunc = trunc <8 x i32> %min to <8 x i16>
   ret <8 x i16> %trunc
 }
+
+define <8 x i16> @pmaddubsw_large_vector(ptr %p1, ptr %p2) {
+; SSE-LABEL: pmaddubsw_large_vector:
+; SSE:       # %bb.0:
+; SSE-NEXT:    movdqa (%rdi), %xmm0
+; SSE-NEXT:    pmaddubsw (%rsi), %xmm0
+; SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: pmaddubsw_large_vector:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa (%rdi), %xmm0
+; AVX-NEXT:    vpmaddubsw (%rsi), %xmm0, %xmm0
+; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX-NEXT:    vpblendw {{.*#+}} xmm0 = xmm1[0,1],xmm0[2],xmm1[3,4],xmm0[5],xmm1[6],xmm0[7]
+; AVX-NEXT:    retq
+  %1 = load <64 x i8>, ptr %p1, align 64
+  %2 = shufflevector <64 x i8> %1, <64 x i8> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %3 = shufflevector <64 x i8> %1, <64 x i8> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %4 = load <32 x i8>, ptr %p2, align 64
+  %5 = shufflevector <32 x i8> %4, <32 x i8> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %6 = shufflevector <32 x i8> %4, <32 x i8> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %7 = sext <8 x i8> %5 to <8 x i32>
+  %8 = zext <8 x i8> %2 to <8 x i32>
+  %9 = mul nsw <8 x i32> %7, %8
+  %10 = sext <8 x i8> %6 to <8 x i32>
+  %11 = zext <8 x i8> %3 to <8 x i32>
+  %12 = mul nsw <8 x i32> %10, %11
+  %13 = add nsw <8 x i32> %9, %12
+  %14 = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %13, <8 x i32> <i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767>)
+  %15 = tail call <8 x i32> @llvm.smax.v8i32(<8 x i32> %14, <8 x i32> <i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768>)
+  %16 = trunc <8 x i32> %15 to <8 x i16>
+  %17 = shufflevector <8 x i16> zeroinitializer, <8 x i16> %16, <8 x i32> <i32 0, i32 1, i32 10, i32 3, i32 4, i32 13, i32 6, i32 15>
+  ret <8 x i16> %17
+}
+
+declare <8 x i32> @llvm.smin.v8i32(<8 x i32>, <8 x i32>)
+declare <8 x i32> @llvm.smax.v8i32(<8 x i32>, <8 x i32>)


        


More information about the llvm-commits mailing list