[llvm] [X86] combine widening `shl` + adjacent addition into `VPMADDWD` (PR #179326)

via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 8 03:06:40 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-x86

Author: Folkert de Vries (folkertdev)

<details>
<summary>Changes</summary>

I added an optimization for `VPMADDWD` earlier in https://github.com/llvm/llvm-project/pull/174149. That one is used in the adler32 checksum. That PR missed another pattern, used in base64 decoding, that uses a `shl` instead of a `mul` but should also be optimized to `VPMADDWD`.

To keep the shift semantically equivalent to the multiplication case, I bail out on shift amounts greater than 15, because `1 << 16` is not representable in an `i16`.

Code-wise, I suspect that I'm missing some convenient way to access the integer values of a constant vector.

---
Full diff: https://github.com/llvm/llvm-project/pull/179326.diff


2 Files Affected:

- (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+35-3) 
- (modified) llvm/test/CodeGen/X86/combine-pmadd.ll (+112) 


``````````diff
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 144d6451b981f..ee6a5ecc1165e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -58696,8 +58696,8 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
   //               (extract_elt Mul, 3),
   //               (extract_elt Mul, 5),
   //                   ...
-  // and identify Mul. Mul must be either ISD::MUL, or can be ISD::SIGN_EXTEND
-  // in which case we add a trivial multiplication by 1.
+  // and identify Mul. Mul must be either ISD::MUL, ISD::SHL, or can be
+  // ISD::SIGN_EXTEND in which case we add a trivial multiplication by 1.
   SDValue Mul;
   for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; i += 2) {
     SDValue Op0L = Op0->getOperand(i), Op1L = Op1->getOperand(i),
@@ -58728,7 +58728,7 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
       // with 2X number of vector elements than the BUILD_VECTOR.
       // Both extracts must be from same MUL.
       Mul = Vec0L;
-      if ((Mul.getOpcode() != ISD::MUL &&
+      if ((Mul.getOpcode() != ISD::MUL && Mul.getOpcode() != ISD::SHL &&
            Mul.getOpcode() != ISD::SIGN_EXTEND) ||
           Mul.getValueType().getVectorNumElements() != 2 * e)
         return SDValue();
@@ -58751,6 +58751,38 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
 
     N0 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(0));
     N1 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(1));
+  } else if (Mul.getOpcode() == ISD::SHL) {
+    SDValue ShVal = Mul.getOperand(0);
+    if (ShVal.getOpcode() != ISD::SIGN_EXTEND)
+      return SDValue();
+
+    N0 = ShVal.getOperand(0);
+    if (N0.getValueType() != TruncVT)
+      return SDValue();
+
+    unsigned NumElts = TruncVT.getVectorNumElements();
+    SmallVector<SDValue, 32> MulConsts;
+    MulConsts.reserve(NumElts);
+
+    auto *BV = dyn_cast<BuildVectorSDNode>(Mul.getOperand(1));
+    if (!BV || BV->getNumOperands() != NumElts)
+      return SDValue();
+
+    for (unsigned i = 0; i != NumElts; ++i) {
+      SDValue E = BV->getOperand(i);
+      if (E.isUndef())
+        return SDValue();
+      auto *C = dyn_cast<ConstantSDNode>(E);
+      if (!C)
+        return SDValue();
+      unsigned ShiftAmount = C->getZExtValue();
+      // A shift by more than 15 would overflow an i16.
+      if (ShiftAmount > 15)
+        return SDValue();
+      MulConsts.push_back(DAG.getConstant(1u << ShiftAmount, DL, MVT::i16));
+    }
+
+    N1 = DAG.getBuildVector(TruncVT, DL, MulConsts);
   } else {
     assert(Mul.getOpcode() == ISD::SIGN_EXTEND);
 
diff --git a/llvm/test/CodeGen/X86/combine-pmadd.ll b/llvm/test/CodeGen/X86/combine-pmadd.ll
index 231b9f97a5e3f..656aff18f02ef 100644
--- a/llvm/test/CodeGen/X86/combine-pmadd.ll
+++ b/llvm/test/CodeGen/X86/combine-pmadd.ll
@@ -360,3 +360,115 @@ define <8 x i32> @sext_pairwise_add(<16 x i16> %x) {
   %4 = add nsw <8 x i32> %2, %3
   ret <8 x i32> %4
 }
+
+define <8 x i32> @combine_with_mul(<16 x i16> %v) {
+; SSE-LABEL: combine_with_mul:
+; SSE:       # %bb.0: # %bb1
+; SSE-NEXT:    movdqa {{.*#+}} xmm2 = [4096,1,4096,1,4096,1,4096,1]
+; SSE-NEXT:    pmaddwd %xmm2, %xmm0
+; SSE-NEXT:    pmaddwd %xmm2, %xmm1
+; SSE-NEXT:    retq
+;
+; AVX1-LABEL: combine_with_mul:
+; AVX1:       # %bb.0: # %bb1
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT:    vbroadcastss {{.*#+}} xmm2 = [4096,1,4096,1,4096,1,4096,1]
+; AVX1-NEXT:    vpmaddwd %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vpmaddwd %xmm2, %xmm0, %xmm0
+; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: combine_with_mul:
+; AVX2:       # %bb.0: # %bb1
+; AVX2-NEXT:    vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 # [4096,1,4096,1,4096,1,4096,1,4096,1,4096,1,4096,1,4096,1]
+; AVX2-NEXT:    retq
+bb1:
+  %0 = sext <16 x i16> %v to <16 x i32>
+  %1 = mul nsw <16 x i32> %0, <i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1, i32 4096, i32 1>
+  %2 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %3 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %4 = add <8 x i32> %2, %3
+  ret <8 x i32> %4
+}
+
+define <8 x i32> @combine_with_shl(<16 x i16> %v) {
+; SSE-LABEL: combine_with_shl:
+; SSE:       # %bb.0: # %bb1
+; SSE-NEXT:    movdqa {{.*#+}} xmm2 = [4096,1,4096,1,4096,1,4096,1]
+; SSE-NEXT:    pmaddwd %xmm2, %xmm0
+; SSE-NEXT:    pmaddwd %xmm2, %xmm1
+; SSE-NEXT:    retq
+;
+; AVX1-LABEL: combine_with_shl:
+; AVX1:       # %bb.0: # %bb1
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT:    vbroadcastss {{.*#+}} xmm2 = [4096,1,4096,1,4096,1,4096,1]
+; AVX1-NEXT:    vpmaddwd %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vpmaddwd %xmm2, %xmm0, %xmm0
+; AVX1-NEXT:    vinsertf128 $1, %xmm1, %ymm0, %ymm0
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: combine_with_shl:
+; AVX2:       # %bb.0: # %bb1
+; AVX2-NEXT:    vpmaddwd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %ymm0, %ymm0 # [4096,1,4096,1,4096,1,4096,1,4096,1,4096,1,4096,1,4096,1]
+; AVX2-NEXT:    retq
+bb1:
+  %0 = sext <16 x i16> %v to <16 x i32>
+  %1 = shl nsw <16 x i32> %0, <i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0, i32 12, i32 0>
+  %2 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %3 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %4 = add <8 x i32> %2, %3
+  ret <8 x i32> %4
+}
+
+; This version cannot use `vpmaddwd` because the multiply would overflow an i16.
+define <8 x i32> @combine_with_shl_overflow(<16 x i16> %v) {
+; SSE-LABEL: combine_with_shl_overflow:
+; SSE:       # %bb.0: # %bb1
+; SSE-NEXT:    pxor %xmm2, %xmm2
+; SSE-NEXT:    pxor %xmm4, %xmm4
+; SSE-NEXT:    punpckhwd {{.*#+}} xmm4 = xmm4[4],xmm1[4],xmm4[5],xmm1[5],xmm4[6],xmm1[6],xmm4[7],xmm1[7]
+; SSE-NEXT:    pxor %xmm3, %xmm3
+; SSE-NEXT:    punpcklwd {{.*#+}} xmm3 = xmm3[0],xmm1[0],xmm3[1],xmm1[1],xmm3[2],xmm1[2],xmm3[3],xmm1[3]
+; SSE-NEXT:    phaddd %xmm4, %xmm3
+; SSE-NEXT:    pxor %xmm1, %xmm1
+; SSE-NEXT:    punpckhwd {{.*#+}} xmm1 = xmm1[4],xmm0[4],xmm1[5],xmm0[5],xmm1[6],xmm0[6],xmm1[7],xmm0[7]
+; SSE-NEXT:    punpcklwd {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
+; SSE-NEXT:    phaddd %xmm1, %xmm2
+; SSE-NEXT:    movdqa %xmm2, %xmm0
+; SSE-NEXT:    movdqa %xmm3, %xmm1
+; SSE-NEXT:    retq
+;
+; AVX1-LABEL: combine_with_shl_overflow:
+; AVX1:       # %bb.0: # %bb1
+; AVX1-NEXT:    vextractf128 $1, %ymm0, %xmm1
+; AVX1-NEXT:    vpxor %xmm2, %xmm2, %xmm2
+; AVX1-NEXT:    vpunpcklwd {{.*#+}} xmm3 = xmm2[0],xmm1[0],xmm2[1],xmm1[1],xmm2[2],xmm1[2],xmm2[3],xmm1[3]
+; AVX1-NEXT:    vpunpcklwd {{.*#+}} xmm4 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3]
+; AVX1-NEXT:    vphaddd %xmm3, %xmm4, %xmm3
+; AVX1-NEXT:    vpunpckhwd {{.*#+}} xmm1 = xmm2[4],xmm1[4],xmm2[5],xmm1[5],xmm2[6],xmm1[6],xmm2[7],xmm1[7]
+; AVX1-NEXT:    vpunpckhwd {{.*#+}} xmm0 = xmm2[4],xmm0[4],xmm2[5],xmm0[5],xmm2[6],xmm0[6],xmm2[7],xmm0[7]
+; AVX1-NEXT:    vphaddd %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vinsertf128 $1, %xmm0, %ymm0, %ymm0
+; AVX1-NEXT:    vinsertf128 $1, %xmm3, %ymm3, %ymm1
+; AVX1-NEXT:    vshufpd {{.*#+}} ymm0 = ymm1[0],ymm0[0],ymm1[3],ymm0[3]
+; AVX1-NEXT:    retq
+;
+; AVX2-LABEL: combine_with_shl_overflow:
+; AVX2:       # %bb.0: # %bb1
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm0
+; AVX2-NEXT:    vpmovzxwd {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero,xmm0[4],zero,xmm0[5],zero,xmm0[6],zero,xmm0[7],zero
+; AVX2-NEXT:    vpslld $16, %ymm0, %ymm0
+; AVX2-NEXT:    vpslld $16, %ymm1, %ymm1
+; AVX2-NEXT:    vphaddd %ymm0, %ymm1, %ymm0
+; AVX2-NEXT:    vpermq {{.*#+}} ymm0 = ymm0[0,2,1,3]
+; AVX2-NEXT:    retq
+bb1:
+  %0 = sext <16 x i16> %v to <16 x i32>
+  %1 = shl nsw <16 x i32> %0, splat (i32 16)
+  %2 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 0, i32 2, i32 4, i32 6, i32 8, i32 10, i32 12, i32 14>
+  %3 = shufflevector <16 x i32> %1, <16 x i32> poison, <8 x i32> <i32 1, i32 3, i32 5, i32 7, i32 9, i32 11, i32 13, i32 15>
+  %4 = add <8 x i32> %2, %3
+  ret <8 x i32> %4
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/179326


More information about the llvm-commits mailing list