[llvm] [DAGCombiner] Fold select into partial.reduce.add operands. (PR #167857)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 17 06:20:20 PST 2025


https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/167857

>From edbebb4050242885528bd6445990c3b8b7ed4270 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Wed, 12 Nov 2025 17:12:33 +0000
Subject: [PATCH 1/3] [DAGCombiner] Fold select into partial.reduce.add
 operands.

This generates better code for partial reductions that use
predication.

partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
-> partial_reduce_*mla(acc, a, sel(p, b, splat(0)))

partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
-> partial.reduce.*mla(acc, op, sel(p, splat(trunc(1)), splat(0)))
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp |  71 ++++++--
 .../partial-reduction-add-predicated.ll       | 159 ++++++++++++++++++
 llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll |  39 +++--
 3 files changed, 242 insertions(+), 27 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index df353c4d91b1a..1cba1d4d5cc22 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13018,22 +13018,37 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
   return SDValue();
 }
 
-// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
 // -> partial_reduce_*mla(acc, a, b)
 //
-// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-// -> partial_reduce_*mla(acc, x, C)
+// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, splat(C))
 //
-// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
-// -> partial_reduce_fmla(acc, a, b)
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, a, sel(p, b, splat(0)))
+//
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, a, sel(p, splat(C), splat(0)))
 SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDLoc DL(N);
   auto *Context = DAG.getContext();
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
-
   unsigned Opc = Op1->getOpcode();
+
+  // Handle predication by moving the SELECT into the operand of the MUL.
+  SDValue Pred;
+  if (Opc == ISD::VSELECT) {
+    APInt C;
+    if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+        !C.isZero())
+      return SDValue();
+    Pred = Op1->getOperand(0);
+    Op1 = Op1->getOperand(1);
+    Opc = Op1->getOpcode();
+  }
+
   if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
     return SDValue();
 
@@ -13068,6 +13083,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue LHSExtOp = LHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
 
+  // Return 'select(P, Op, splat(0))' if P is nonzero,
+  // or 'P' otherwise.
+  auto tryPredicate = [&](SDValue P, SDValue Op) {
+    if (!P)
+      return Op;
+    EVT OpVT = Op.getValueType();
+    SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
+                                          : DAG.getConstant(0, DL, OpVT);
+    return DAG.getSelect(DL, OpVT, P, Op, Zero);
+  };
+
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
   // -> partial_reduce_*mla(acc, x, C)
   APInt C;
@@ -13090,8 +13116,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
             TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
       return SDValue();
 
+    SDValue Constant =
+        tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
     return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
-                       DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+                       Constant);
   }
 
   unsigned RHSOpcode = RHS->getOpcode();
@@ -13132,17 +13160,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
+  RHSExtOp = tryPredicate(Pred, RHSExtOp);
   return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
 }
 
-// partial.reduce.umla(acc, zext(op), splat(1))
-// -> partial.reduce.umla(acc, op, splat(trunc(1)))
-// partial.reduce.smla(acc, sext(op), splat(1))
-// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.*mla(acc, *ext(op), splat(1))
+// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
 // partial.reduce.sumla(acc, sext(op), splat(1))
 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
-// partial.reduce.fmla(acc, fpext(op), splat(1.0))
-// -> partial.reduce.fmla(acc, op, splat(1.0))
+//
+// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
+// -> partial.reduce.*mla(acc, op, sel(p, splat(trunc(1)), splat(0)))
 SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   SDLoc DL(N);
   SDValue Acc = N->getOperand(0);
@@ -13152,7 +13180,18 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
     return SDValue();
 
+  SDValue Pred;
   unsigned Op1Opcode = Op1.getOpcode();
+  if (Op1Opcode == ISD::VSELECT) {
+    APInt C;
+    if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+        !C.isZero())
+      return SDValue();
+    Pred = Op1->getOperand(0);
+    Op1 = Op1->getOperand(1);
+    Op1Opcode = Op1->getOpcode();
+  }
+
   if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
     return SDValue();
 
@@ -13181,6 +13220,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
                          ? DAG.getConstantFP(1, DL, UnextOp1VT)
                          : DAG.getConstant(1, DL, UnextOp1VT);
 
+  if (Pred) {
+    SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+                       ? DAG.getConstantFP(0, DL, UnextOp1VT)
+                       : DAG.getConstant(0, DL, UnextOp1VT);
+    Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
+  }
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
                      Constant);
 }
diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
new file mode 100644
index 0000000000000..24cdd0a852222
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
+
+define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v1.16b, v1.16b, #7
+; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %ext.2 = sext <16 x i8> %b to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, %ext.2
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v1.16b, v1.16b, #7
+; CHECK-NEXT:    movi v3.16b, #127
+; CHECK-NEXT:    cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, splat(i32 127)
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    sel z2.b, p0, z2.b, z3.b
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, p0/z, #127 // =0x7f
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127)
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x i32> @predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v3.16b, #1
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT:    ret
+ %ext = sext <16 x i8> %a to <16 x i32>
+ %sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z2.b, p0/z, #1 // =0x1
+; CHECK-NEXT:    sdot z0.s, z1.b, z2.b
+; CHECK-NEXT:    ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    shl v1.8h, v1.8h, #15
+; CHECK-NEXT:    cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    fdot z0.s, z2.h, z1.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+ %ext.1 = fpext <8 x half> %a to <8 x float>
+ %ext.2 = fpext <8 x half> %b to <8 x float>
+ %mul = fmul <8 x float> %ext.1, %ext.2
+ %sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    sel z2.h, p0, z2.h, z3.h
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+ %ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mul = fmul <vscale x 8 x float> %ext.1, %ext.2
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_fixed_length:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    movi v3.8h, #60, lsl #8
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    shl v1.8h, v1.8h, #15
+; CHECK-NEXT:    cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT:    and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT:    fdot z0.s, z2.h, z1.h
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+ %ext = fpext <8 x half> %a to <8 x float>
+ %sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_scalable:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    fmov z2.h, p0/m, #1.00000000
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+ %ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+attributes #0 = { nounwind "target-features"="+sve,+dotprod" }
+attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" }
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 72bf1fa9a8327..d6384a6913efe 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -996,20 +996,31 @@ entry:
 }
 
 define <vscale x 2 x i32> @partial_reduce_select(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b, <vscale x 8 x i1> %m) {
-; CHECK-LABEL: partial_reduce_select:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a0, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vsext.vf2 v14, v9
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.v.i v8, 0
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, mu
-; CHECK-NEXT:    vwmul.vv v8, v12, v14, v0.t
-; CHECK-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vadd.vv v8, v11, v8
-; CHECK-NEXT:    vadd.vv v9, v9, v10
-; CHECK-NEXT:    vadd.vv v8, v9, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: partial_reduce_select:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetvli a0, zero, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vsext.vf2 v14, v9
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.v.i v8, 0
+; NODOT-NEXT:    vsetvli zero, zero, e16, m2, ta, mu
+; NODOT-NEXT:    vwmul.vv v8, v12, v14, v0.t
+; NODOT-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; NODOT-NEXT:    vadd.vv v8, v11, v8
+; NODOT-NEXT:    vadd.vv v9, v9, v10
+; NODOT-NEXT:    vadd.vv v8, v9, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: partial_reduce_select:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetvli a0, zero, e8, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vmerge.vvm v10, v10, v9, v0
+; DOT-NEXT:    vsetvli a0, zero, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v9, 0
+; DOT-NEXT:    vqdot.vv v9, v8, v10
+; DOT-NEXT:    vmv.v.v v8, v9
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <vscale x 8 x i8> %a to <vscale x 8 x i32>
   %b.sext = sext <vscale x 8 x i8> %b to <vscale x 8 x i32>

>From 7506ebc2766b7ff82b6c8902e19d46c81d74b1fb Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 13 Nov 2025 13:53:22 +0000
Subject: [PATCH 2/3] Add freeze

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 22 ++++++++++---------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 1cba1d4d5cc22..ff98c0910f4d3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13083,15 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue LHSExtOp = LHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
 
-  // Return 'select(P, Op, splat(0))' if P is nonzero,
-  // or 'P' otherwise.
-  auto tryPredicate = [&](SDValue P, SDValue Op) {
+  // Sets Op = select(P, Op, splat(0)) if P is nonzero, or Op otherwise.
+  // Set ToFreezeOp = freeze(ToFreezeOp) if the value may be poison, to
+  // keep the same semantics.
+  auto ApplyPredicate = [&](SDValue P, SDValue &Op, SDValue &ToFreezeOp) {
     if (!P)
-      return Op;
+      return;
+    if (!DAG.isGuaranteedNotToBePoison(ToFreezeOp))
+      ToFreezeOp = DAG.getFreeze(ToFreezeOp);
     EVT OpVT = Op.getValueType();
     SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
                                           : DAG.getConstant(0, DL, OpVT);
-    return DAG.getSelect(DL, OpVT, P, Op, Zero);
+    Op = DAG.getSelect(DL, OpVT, P, Op, Zero);
   };
 
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
@@ -13116,10 +13119,9 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
             TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
       return SDValue();
 
-    SDValue Constant =
-        tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
-    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
-                       Constant);
+    SDValue C = DAG.getConstant(CTrunc, DL, LHSExtOpVT);
+    ApplyPredicate(Pred, C, LHSExtOp);
+    return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, C);
   }
 
   unsigned RHSOpcode = RHS->getOpcode();
@@ -13160,7 +13162,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
-  RHSExtOp = tryPredicate(Pred, RHSExtOp);
+  ApplyPredicate(Pred, RHSExtOp, LHSExtOp);
   return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
 }
 

>From ccee504ed7b27011c501079079a697c987819efc Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Mon, 17 Nov 2025 08:49:06 +0000
Subject: [PATCH 3/3] Address suggestions

---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h |  4 ++
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 41 ++++++++-----------
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  5 +++
 3 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index cd466dceb900f..cfc8a4243e894 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1968,6 +1968,10 @@ LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
 /// Build vector implicit truncation is allowed.
 LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
 
+/// Return true if the value is a constant (+/-)0.0 floating-point value or a
+/// splatted vector thereof (with no undefs by default).
+LLVM_ABI bool isZeroOrZeroSplatFP(SDValue N, bool AllowUndefs = false);
+
 /// Return true if \p V is either a integer or FP constant.
 inline bool isIntOrFPConstant(SDValue V) {
   return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ff98c0910f4d3..59587329493fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13039,11 +13039,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
 
   // Handle predication by moving the SELECT into the operand of the MUL.
   SDValue Pred;
-  if (Opc == ISD::VSELECT) {
-    APInt C;
-    if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
-        !C.isZero())
-      return SDValue();
+  if (Opc == ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) ||
+                              isZeroOrZeroSplatFP(Op1->getOperand(2)))) {
     Pred = Op1->getOperand(0);
     Op1 = Op1->getOperand(1);
     Opc = Op1->getOpcode();
@@ -13083,18 +13080,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue LHSExtOp = LHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
 
-  // Sets Op = select(P, Op, splat(0)) if P is nonzero, or Op otherwise.
-  // Set ToFreezeOp = freeze(ToFreezeOp) if the value may be poison, to
-  // keep the same semantics.
-  auto ApplyPredicate = [&](SDValue P, SDValue &Op, SDValue &ToFreezeOp) {
-    if (!P)
-      return;
-    if (!DAG.isGuaranteedNotToBePoison(ToFreezeOp))
-      ToFreezeOp = DAG.getFreeze(ToFreezeOp);
-    EVT OpVT = Op.getValueType();
-    SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
-                                          : DAG.getConstant(0, DL, OpVT);
-    Op = DAG.getSelect(DL, OpVT, P, Op, Zero);
+  // When Pred is set, replace Op with select(Pred, Op, splat(0)) and freeze
+  // OtherOp, so that poison in OtherOp's disabled lanes cannot leak into the
+  // result now that the select has moved into the MUL operands.
+  auto ApplyPredicate = [&](SDValue &Op, SDValue &OtherOp) {
+    if (Pred) {
+      EVT OpVT = Op.getValueType();
+      SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
+                                            : DAG.getConstant(0, DL, OpVT);
+      Op = DAG.getSelect(DL, OpVT, Pred, Op, Zero);
+      OtherOp = DAG.getFreeze(OtherOp);
+    }
   };
 
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
@@ -13120,7 +13116,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
       return SDValue();
 
     SDValue C = DAG.getConstant(CTrunc, DL, LHSExtOpVT);
-    ApplyPredicate(Pred, C, LHSExtOp);
+    ApplyPredicate(C, LHSExtOp);
     return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, C);
   }
 
@@ -13162,7 +13158,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
-  ApplyPredicate(Pred, RHSExtOp, LHSExtOp);
+  ApplyPredicate(RHSExtOp, LHSExtOp);
   return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
 }
 
@@ -13184,11 +13180,8 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
 
   SDValue Pred;
   unsigned Op1Opcode = Op1.getOpcode();
-  if (Op1Opcode == ISD::VSELECT) {
-    APInt C;
-    if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
-        !C.isZero())
-      return SDValue();
+  if (Op1Opcode == ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) ||
+                                    isZeroOrZeroSplatFP(Op1->getOperand(2)))) {
     Pred = Op1->getOperand(0);
     Op1 = Op1->getOperand(1);
     Op1Opcode = Op1->getOpcode();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index c2b4c19846316..16fdef06d6679 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12971,6 +12971,11 @@ bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
   return C && C->isZero();
 }
 
+bool llvm::isZeroOrZeroSplatFP(SDValue N, bool AllowUndefs) {
+  ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs);
+  return C && C->isZero();
+}
+
 HandleSDNode::~HandleSDNode() {
   DropOperands();
 }



More information about the llvm-commits mailing list