[llvm] [X86] Allow input vector extracted from larger vector when combining to VPMADDUBSW (PR #89584)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 22 03:07:36 PDT 2024


https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/89584

Combining to VPMADDUBSW currently fails on main trunk when the input vector is extracted from a larger vector: https://godbolt.org/z/edWMz8chE

>From c0335cb2589b9f3619efab3bb695c893de3e4ba8 Mon Sep 17 00:00:00 2001
From: Phoebe Wang <phoebe.wang at intel.com>
Date: Mon, 22 Apr 2024 17:55:43 +0800
Subject: [PATCH] [X86] Allow input vector extracted from larger vector when
 combining to VPMADDUBSW

Combining to VPMADDUBSW currently fails on main trunk when the input
vector is extracted from a larger vector: https://godbolt.org/z/edWMz8chE
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 11 +++++++
 llvm/test/CodeGen/X86/pmaddubsw.ll      | 38 +++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 3a51c7c2ca854e..dd40d079c7e2f7 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -51841,6 +51841,17 @@ static SDValue detectPMADDUBSW(SDValue In, EVT VT, SelectionDAG &DAG,
       return SDValue();
   }
 
+  auto ExtractVec = [&DAG, &DL, NumElems](SDValue &Ext) {
+    EVT ExtVT = Ext.getValueType();
+    if (ExtVT.getVectorNumElements() != NumElems * 2) {
+      MVT NVT = MVT::getVectorVT(MVT::i8, NumElems * 2);
+      Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, NVT, Ext,
+                        DAG.getIntPtrConstant(0, DL));
+    }
+  };
+  ExtractVec(ZExtIn);
+  ExtractVec(SExtIn);
+
   auto PMADDBuilder = [](SelectionDAG &DAG, const SDLoc &DL,
                          ArrayRef<SDValue> Ops) {
     // Shrink by adding truncate nodes and let DAGCombine fold with the
diff --git a/llvm/test/CodeGen/X86/pmaddubsw.ll b/llvm/test/CodeGen/X86/pmaddubsw.ll
index e46a14673a5171..d6c9877cd99b63 100644
--- a/llvm/test/CodeGen/X86/pmaddubsw.ll
+++ b/llvm/test/CodeGen/X86/pmaddubsw.ll
@@ -469,3 +469,41 @@ define <8 x i16> @pmaddubsw_bad_indices(ptr %Aptr, ptr %Bptr) {
   %trunc = trunc <8 x i32> %min to <8 x i16>
   ret <8 x i16> %trunc
 }
+
+define <8 x i16> @pmaddubsw_large_vector(ptr %p1, ptr %p2) {
+; SSE-LABEL: pmaddubsw_large_vector:
+; SSE:       # %bb.0:
+; SSE-NEXT:    movdqa (%rdi), %xmm0
+; SSE-NEXT:    pmaddubsw (%rsi), %xmm0
+; SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; SSE-NEXT:    retq
+;
+; AVX-LABEL: pmaddubsw_large_vector:
+; AVX:       # %bb.0:
+; AVX-NEXT:    vmovdqa (%rdi), %xmm0
+; AVX-NEXT:    vpmaddubsw (%rsi), %xmm0, %xmm0
+; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX-NEXT:    vpblendw {{.*#+}} xmm0 = xmm1[0,1],xmm0[2],xmm1[3,4],xmm0[5],xmm1[6],xmm0[7]
+; AVX-NEXT:    retq
+  %1 = load <64 x i8>, ptr %p1, align 64
+  %2 = shufflevector <64 x i8> %1, <64 x i8> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %3 = shufflevector <64 x i8> %1, <64 x i8> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %4 = load <32 x i8>, ptr %p2, align 64
+  %5 = shufflevector <32 x i8> %4, <32 x i8> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %6 = shufflevector <32 x i8> %4, <32 x i8> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %7 = sext <8 x i8> %5 to <8 x i32>
+  %8 = zext <8 x i8> %2 to <8 x i32>
+  %9 = mul nsw <8 x i32> %7, %8
+  %10 = sext <8 x i8> %6 to <8 x i32>
+  %11 = zext <8 x i8> %3 to <8 x i32>
+  %12 = mul nsw <8 x i32> %10, %11
+  %13 = add nsw <8 x i32> %9, %12
+  %14 = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %13, <8 x i32> <i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767, i32 32767>)
+  %15 = tail call <8 x i32> @llvm.smax.v8i32(<8 x i32> %14, <8 x i32> <i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768, i32 -32768>)
+  %16 = trunc <8 x i32> %15 to <8 x i16>
+  %17 = shufflevector <8 x i16> zeroinitializer, <8 x i16> %16, <8 x i32> <i32 0, i32 1, i32 10, i32 3, i32 4, i32 13, i32 6, i32 15>
+  ret <8 x i16> %17
+}
+
+declare <8 x i32> @llvm.smin.v8i32(<8 x i32>, <8 x i32>)
+declare <8 x i32> @llvm.smax.v8i32(<8 x i32>, <8 x i32>)



More information about the llvm-commits mailing list