[llvm] [DAGCombiner] Fold select into partial.reduce.add operands. (PR #167857)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Sun Nov 16 19:25:38 PST 2025


================
@@ -13018,22 +13018,37 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   return SDValue();
 }
 
-// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
 // -> partial_reduce_*mla(acc, a, b)
 //
-// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-// -> partial_reduce_*mla(acc, x, C)
+// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, splat(C))
 //
-// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
-// -> partial_reduce_fmla(acc, a, b)
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), b)
+//
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, sel(p, a, splat(0)), splat(C))
 SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDLoc DL(N);
   auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
-
   unsigned Opc = Op1->getOpcode();
+
+  // Handle predication by moving the SELECT into the operand of the MUL.
+  SDValue Pred;
+  if (Opc == ISD::VSELECT) {
+    APInt C;
+    if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+        !C.isZero())
+      return SDValue();
----------------
lukel97 wrote:

```suggestion
  if (Opc == ISD::VSELECT && isNullOrNullSplat(Op1->getOperand(2))) {
```

https://github.com/llvm/llvm-project/pull/167857


More information about the llvm-commits mailing list