[llvm] [LV] Bundle (partial) reductions with a mul of a constant (PR #162503)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 24 06:44:34 PDT 2025


================
@@ -3668,18 +3701,23 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
   // variants.
   if (Sub)
     return nullptr;
-  // Match reduce.add(ext(mul(ext(A), ext(B)))).
-  // All extend recipes must have same opcode or A == B
-  // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
-  if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
-                                      m_ZExtOrSExt(m_VPValue()))))) {
+
+  // Match reduce.add(ext(mul(A, B))).
+  if (match(VecOp, m_ZExtOrSExt(m_Mul(m_VPValue(A), m_VPValue(B))))) {
     auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
     auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
-    auto *Ext0 =
-        cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
-    auto *Ext1 =
-        cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
-    if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
+    auto *Ext0 = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
+    auto *Ext1 = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
+
+    // reduce.add(ext(mul(ext, const)))
+    // -> reduce.add(ext(mul(ext, ext(const))))
+    ExtendAndReplaceConstantOp(Ext0, Ext1, B, Mul);
+
+    // Match reduce.add(ext(mul(ext(A), ext(B))))
+    // All extend recipes must have same opcode or A == B
+    // which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
----------------
sdesmalen-arm wrote:

I know you've just copied this from above, but this comment is not accurate.

The cases it tries to handle are:
```
reduce.add(zext(mul(zext(a), zext(b))))
-> reduce.add(mul(wider_zext(a), wider_zext(b)))

reduce.add(sext(mul(sext(a), sext(b))))
-> reduce.add(mul(wider_sext(a), wider_sext(b)))
```
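As an illustration only, the two matching-opcode rewrites above can be checked with a scalar model of a single vector lane. A `reduce.add` is a sum over lanes, so if every lane agrees the reductions agree too. The concrete widths (i8 inputs, i16 `mul`, i32 accumulator) and the function names are assumptions chosen for the sketch. They are not taken from the VPlan code. The last function shows why the extend opcodes have to match: mixing a `sext` inside with a `zext` outside gives a different answer as soon as the narrow product is negative.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical check, assuming i8 -> i16 -> i32 widths:
//   zext32(mul16(zext16(a), zext16(b))) == mul32(zext32(a), zext32(b))
bool zextChainFolds() {
  for (int i = 0; i < 256; ++i) {
    for (int j = 0; j < 256; ++j) {
      uint8_t a = static_cast<uint8_t>(i), b = static_cast<uint8_t>(j);
      // Narrow multiply; 255 * 255 = 65025 fits in uint16_t, so nothing wraps.
      uint16_t narrow = static_cast<uint16_t>(uint16_t(a) * uint16_t(b));
      uint32_t viaNarrow = narrow;               // outer zext to 32 bits
      uint32_t wide = uint32_t(a) * uint32_t(b); // wider_zext(a) * wider_zext(b)
      if (viaNarrow != wide)
        return false;
    }
  }
  return true;
}

// Hypothetical check, same widths:
//   sext32(mul16(sext16(a), sext16(b))) == mul32(sext32(a), sext32(b))
bool sextChainFolds() {
  for (int i = -128; i < 128; ++i) {
    for (int j = -128; j < 128; ++j) {
      int8_t a = static_cast<int8_t>(i), b = static_cast<int8_t>(j);
      // Products lie in [-16256, 16384], so they fit in int16_t.
      int16_t narrow = static_cast<int16_t>(int16_t(a) * int16_t(b));
      int32_t viaNarrow = narrow;              // outer sext to 32 bits
      int32_t wide = int32_t(a) * int32_t(b);  // wider_sext(a) * wider_sext(b)
      if (viaNarrow != wide)
        return false;
    }
  }
  return true;
}

// Mixed extends do not fold in general: for a = -1, b = 1 the narrow product
// is -1 (0xFFFF), whose zext to 32 bits is 65535 rather than -1.
bool mixedChainFoldsForMinusOne() {
  int8_t a = -1, b = 1;
  int16_t narrow = static_cast<int16_t>(int16_t(a) * int16_t(b));
  uint32_t viaNarrow = static_cast<uint16_t>(narrow); // outer zext
  int32_t wide = int32_t(a) * int32_t(b);
  return static_cast<int64_t>(viaNarrow) == static_cast<int64_t>(wide);
}
```

A test driver would expect the first two functions to return `true` and the third to return `false`.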

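The other case (and the reason for checking `Ext0 == Ext1`) is when both `mul` operands come from the same extend. The `mul` then squares a single value, so its result is non-negative. A zero-extend of a non-negative value is the same as a sign-extend, so the final zero-extend can be folded away, i.e.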
```
reduce.add(zext(mul(sext(a), sext(a)))) // result of mul is nneg
-> reduce.add(mul(wider_sext(a), wider_sext(a)))
```
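The non-negative case can be modelled the same way, again as a sketch under assumed i8/i16/i32 widths with a hypothetical function name. Squaring a sign-extended i8 gives a 16-bit product in [0, 16384]. Because that product is non-negative, zero-extending it gives the same value as sign-extending it, so the whole chain equals a single wider signed multiply.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical exhaustive check, assuming i8 -> i16 -> i32 widths:
//   zext32(mul16(sext16(a), sext16(a))) == mul32(sext32(a), sext32(a))
bool squareZextFolds() {
  for (int i = -128; i < 128; ++i) {
    int8_t a = static_cast<int8_t>(i);
    int16_t narrow = static_cast<int16_t>(int16_t(a) * int16_t(a));
    if (narrow < 0) // the square must be non-negative (max 128 * 128 = 16384)
      return false;
    // Reinterpret the 16 bits as unsigned before widening to model zext.
    uint32_t viaNarrow = static_cast<uint16_t>(narrow);
    int32_t wide = int32_t(a) * int32_t(a); // wider_sext(a) * wider_sext(a)
    if (static_cast<int64_t>(viaNarrow) != static_cast<int64_t>(wide))
      return false;
  }
  return true;
}
```

This check covers only the case where both operands are the same value. With two different operands the narrow product can be negative, and the outer `zext` no longer matches a `sext`.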

https://github.com/llvm/llvm-project/pull/162503

