[llvm] [X86] combine widening `shl` + adjacent addition into `VPMADDWD` (PR #179326)

Folkert de Vries via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 3 11:45:57 PST 2026


================
@@ -58751,6 +58751,38 @@ static SDValue matchPMADDWD(SelectionDAG &DAG, SDNode *N,
 
     N0 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(0));
     N1 = DAG.getNode(ISD::TRUNCATE, DL, TruncVT, Mul.getOperand(1));
+  } else if (Mul.getOpcode() == ISD::SHL) {
+    SDValue ShVal = Mul.getOperand(0);
+    if (ShVal.getOpcode() != ISD::SIGN_EXTEND)
+      return SDValue();
+
+    N0 = ShVal.getOperand(0);
+    if (N0.getValueType() != TruncVT)
+      return SDValue();
+
+    unsigned NumElts = TruncVT.getVectorNumElements();
+    SmallVector<SDValue, 32> MulConsts;
+    MulConsts.reserve(NumElts);
+
+    auto *BV = dyn_cast<BuildVectorSDNode>(Mul.getOperand(1));
+    if (!BV || BV->getNumOperands() != NumElts)
+      return SDValue();
+
+    for (unsigned i = 0; i != NumElts; ++i) {
+      SDValue E = BV->getOperand(i);
+      if (E.isUndef())
+        return SDValue();
+      auto *C = dyn_cast<ConstantSDNode>(E);
+      if (!C)
+        return SDValue();
+      unsigned ShiftAmount = C->getZExtValue();
+      // A shift by more than 15 would overflow an i16.
+      if (ShiftAmount > 15)
+        return SDValue();
+      MulConsts.push_back(DAG.getConstant(1u << ShiftAmount, DL, MVT::i16));
----------------
folkertdev wrote:

Would that be simpler? After a cursory look, I think checking that each shift amount is at most 15 would still require explicitly looping over every element.

https://github.com/llvm/llvm-project/pull/179326


More information about the llvm-commits mailing list