[llvm] [DAGCombiner] Fold select into partial.reduce.add operands. (PR #167857)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 17 06:20:20 PST 2025
https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/167857
>From edbebb4050242885528bd6445990c3b8b7ed4270 Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Wed, 12 Nov 2025 17:12:33 +0000
Subject: [PATCH 1/3] [DAGCombiner] Fold select into partial.reduce.add
operands.
This generates better code when partial reductions are used
with predication:
partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
-> partial_reduce_*mla(acc, a, sel(p, b, splat(0)))
partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
-> partial.reduce.*mla(acc, op, sel(p, splat(trunc(1)), splat(0)))
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 71 ++++++--
.../partial-reduction-add-predicated.ll | 159 ++++++++++++++++++
llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll | 39 +++--
3 files changed, 242 insertions(+), 27 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index df353c4d91b1a..1cba1d4d5cc22 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13018,22 +13018,37 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
return SDValue();
}
-// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
+// partial_reduce_*mla(acc, mul(*ext(a), *ext(b)), splat(1))
// -> partial_reduce_*mla(acc, a, b)
//
-// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
-// -> partial_reduce_*mla(acc, x, C)
+// partial_reduce_*mla(acc, mul(*ext(x), splat(C)), splat(1))
+// -> partial_reduce_*mla(acc, x, splat(C))
//
-// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
-// -> partial_reduce_fmla(acc, a, b)
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), *ext(b)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, freeze(a), sel(p, b, splat(0)))
+//
+// partial_reduce_*mla(acc, sel(p, mul(*ext(a), splat(C)), splat(0)), splat(1))
+// -> partial_reduce_*mla(acc, freeze(a), sel(p, splat(C), splat(0)))
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
-
unsigned Opc = Op1->getOpcode();
+
+ // Handle predication by moving the SELECT into the operand of the MUL.
+ SDValue Pred;
+ if (Opc == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Opc = Op1->getOpcode();
+ }
+
if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();
@@ -13068,6 +13083,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
+ // Return 'select(P, Op, splat(0))' if P is set (non-null),
+ // or 'Op' unchanged otherwise.
+ auto tryPredicate = [&](SDValue P, SDValue Op) {
+ if (!P)
+ return Op;
+ EVT OpVT = Op.getValueType();
+ SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
+ : DAG.getConstant(0, DL, OpVT);
+ return DAG.getSelect(DL, OpVT, P, Op, Zero);
+ };
+
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
APInt C;
@@ -13090,8 +13116,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ SDValue Constant =
+ tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
- DAG.getConstant(CTrunc, DL, LHSExtOpVT));
+ Constant);
}
unsigned RHSOpcode = RHS->getOpcode();
@@ -13132,17 +13160,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
+ RHSExtOp = tryPredicate(Pred, RHSExtOp);
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}
-// partial.reduce.umla(acc, zext(op), splat(1))
-// -> partial.reduce.umla(acc, op, splat(trunc(1)))
-// partial.reduce.smla(acc, sext(op), splat(1))
-// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.*mla(acc, *ext(op), splat(1))
+// -> partial.reduce.*mla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
-// partial.reduce.fmla(acc, fpext(op), splat(1.0))
-// -> partial.reduce.fmla(acc, op, splat(1.0))
+//
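+// partial.reduce.*mla(acc, sel(p, *ext(op), splat(0)), splat(1))
+// -> partial.reduce.*mla(acc, op, sel(p, splat(trunc(1)), splat(0)))
+// NOTE(review): op does not appear to be frozen, so lanes where p is false
+// compute op * 0: poison if op is poison there, NaN for FP if op is Inf/NaN.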
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
@@ -13152,7 +13180,18 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();
+ SDValue Pred;
unsigned Op1Opcode = Op1.getOpcode();
+ if (Op1Opcode == ISD::VSELECT) {
+ APInt C;
+ if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
+ !C.isZero())
+ return SDValue();
+ Pred = Op1->getOperand(0);
+ Op1 = Op1->getOperand(1);
+ Op1Opcode = Op1->getOpcode();
+ }
+
if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();
@@ -13181,6 +13220,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
? DAG.getConstantFP(1, DL, UnextOp1VT)
: DAG.getConstant(1, DL, UnextOp1VT);
+ if (Pred) {
+ SDValue Zero = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? DAG.getConstantFP(0, DL, UnextOp1VT)
+ : DAG.getConstant(0, DL, UnextOp1VT);
+ Constant = DAG.getSelect(DL, UnextOp1VT, Pred, Constant, Zero);
+ }
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
Constant);
}
diff --git a/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
new file mode 100644
index 0000000000000..24cdd0a852222
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduction-add-predicated.ll
@@ -0,0 +1,159 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
+
+; Predicated product of two sign-extended i8 vectors fed to
+; partial.reduce.add; expected to lower to a single sdot.
+define <4 x i32> @predicate_dot_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a, <16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %ext.2 = sext <16 x i8> %b to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, %ext.2
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+; Predicated multiply of a sign-extended i8 vector by splat(127), a constant
+; that fits in i8; expected to lower to a single sdot.
+define <4 x i32> @predicate_dot_by_C_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v1.16b, v1.16b, #7
+; CHECK-NEXT: movi v3.16b, #127
+; CHECK-NEXT: cmlt v1.16b, v1.16b, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext.1 = sext <16 x i8> %a to <16 x i32>
+ %mul = mul nsw <16 x i32> %ext.1, splat(i32 127)
+ %sel = select <16 x i1> %p, <16 x i32> %mul, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+; Scalable-vector predicated sext*sext product; expected to lower to a
+; sel of one operand followed by sdot.
+define <vscale x 4 x i32> @predicate_dot_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) #0 {
+; CHECK-LABEL: predicate_dot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.b, p0, z2.b, z3.b
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %ext.2 = sext <vscale x 16 x i8> %b to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, %ext.2
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+; Scalable-vector predicated multiply by splat(127); the constant is
+; materialized with zeroing predication under p0 and fed to sdot.
+define <vscale x 4 x i32> @predicate_dot_by_C_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_dot_by_C_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #127 // =0x7f
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext.1 = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %mul = mul nsw <vscale x 16 x i32> %ext.1, splat(i32 127)
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %mul, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+; Predicated sext with no multiply: reduced as sdot of %a against a
+; masked splat(1).
+define <4 x i32> @predicate_ext_mul_fixed_length(<4 x i32> %acc, <16 x i1> %p, <16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.16b, #1
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+ %ext = sext <16 x i8> %a to <16 x i32>
+ %sel = select <16 x i1> %p, <16 x i32> %ext, <16 x i32> zeroinitializer
+ %red = call <4 x i32> @llvm.vector.partial.reduce.add(<4 x i32> %acc, <16 x i32> %sel)
+ ret <4 x i32> %red
+}
+
+; Scalable-vector predicated sext with no multiply: sdot of %a against
+; splat(1) materialized with zeroing predication under p0.
+define <vscale x 4 x i32> @predicate_ext_mul_scalable(<vscale x 4 x i32> %acc, <vscale x 16 x i1> %p, <vscale x 16 x i8> %a) #0 {
+; CHECK-LABEL: predicate_ext_mul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, p0/z, #1 // =0x1
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %sel = select <vscale x 16 x i1> %p, <vscale x 16 x i32> %ext, <vscale x 16 x i32> zeroinitializer
+ %red = call <vscale x 4 x i32> @llvm.vector.partial.reduce.add(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %sel)
+ ret <vscale x 4 x i32> %red
+}
+
+; Predicated product of two fpext'ed half vectors fed to partial.reduce.fadd;
+; expected to lower to fdot with only %b masked.
+; NOTE(review): lanes where %p is false then compute %a * 0.0, which is NaN
+; (not 0.0) if %a is Inf/NaN there; confirm the fold is valid for FP.
+define <4 x float> @predicated_fdot_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a, <8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext.1 = fpext <8 x half> %a to <8 x float>
+ %ext.2 = fpext <8 x half> %b to <8 x float>
+ %mul = fmul <8 x float> %ext.1, %ext.2
+ %sel = select <8 x i1> %p, <8 x float> %mul, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+; Scalable-vector predicated fpext*fpext product; expected to lower to a
+; sel of %b followed by fdot.
+; NOTE(review): lanes where %p is false then compute %a * 0.0, which is NaN
+; (not 0.0) if %a is Inf/NaN there; confirm the fold is valid for FP.
+define <vscale x 4 x float> @predicated_fdot_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a, <vscale x 8 x half> %b) #1 {
+; CHECK-LABEL: predicated_fdot_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: sel z2.h, p0, z2.h, z3.h
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext.1 = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %ext.2 = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mul = fmul <vscale x 8 x float> %ext.1, %ext.2
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %mul, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+; Predicated fpext with no fmul: reduced as fdot of %a against a masked
+; splat(1.0).
+; NOTE(review): lanes where %p is false then compute %a * 0.0, which is NaN
+; (not 0.0) if %a is Inf/NaN there; confirm the fold is valid for FP.
+define <4 x float> @predicated_fpext_fmul_fixed_length(<4 x float> %acc, <8 x i1> %p, <8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_fixed_length:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NEXT: movi v3.8h, #60, lsl #8
+; CHECK-NEXT: // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT: // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT: shl v1.8h, v1.8h, #15
+; CHECK-NEXT: cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT: and v1.16b, v1.16b, v3.16b
+; CHECK-NEXT: fdot z0.s, z2.h, z1.h
+; CHECK-NEXT: // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT: ret
+ %ext = fpext <8 x half> %a to <8 x float>
+ %sel = select <8 x i1> %p, <8 x float> %ext, <8 x float> zeroinitializer
+ %red = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %sel)
+ ret <4 x float> %red
+}
+
+; Scalable-vector predicated fpext with no fmul: fdot of %a against splat(1.0)
+; merged into a zeroed register under p0.
+; NOTE(review): lanes where %p is false then compute %a * 0.0, which is NaN
+; (not 0.0) if %a is Inf/NaN there; confirm the fold is valid for FP.
+define <vscale x 4 x float> @predicated_fpext_fmul_scalable(<vscale x 4 x float> %acc, <vscale x 8 x i1> %p, <vscale x 8 x half> %a) #1 {
+; CHECK-LABEL: predicated_fpext_fmul_scalable:
+; CHECK: // %bb.0:
+; CHECK-NEXT: movi v2.2d, #0000000000000000
+; CHECK-NEXT: fmov z2.h, p0/m, #1.00000000
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+ %ext = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %sel = select <vscale x 8 x i1> %p, <vscale x 8 x float> %ext, <vscale x 8 x float> zeroinitializer
+ %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %sel)
+ ret <vscale x 4 x float> %red
+}
+
+attributes #0 = { nounwind "target-features"="+sve,+dotprod" }
+attributes #1 = { nounwind "target-features"="+sve2p1,+dotprod" }
diff --git a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
index 72bf1fa9a8327..d6384a6913efe 100644
--- a/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll
@@ -996,20 +996,31 @@ entry:
}
define <vscale x 2 x i32> @partial_reduce_select(<vscale x 8 x i8> %a, <vscale x 8 x i8> %b, <vscale x 8 x i1> %m) {
-; CHECK-LABEL: partial_reduce_select:
-; CHECK: # %bb.0: # %entry
-; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, ma
-; CHECK-NEXT: vsext.vf2 v12, v8
-; CHECK-NEXT: vsext.vf2 v14, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT: vmv.v.i v8, 0
-; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, mu
-; CHECK-NEXT: vwmul.vv v8, v12, v14, v0.t
-; CHECK-NEXT: vsetvli a0, zero, e32, m1, ta, ma
-; CHECK-NEXT: vadd.vv v8, v11, v8
-; CHECK-NEXT: vadd.vv v9, v9, v10
-; CHECK-NEXT: vadd.vv v8, v9, v8
-; CHECK-NEXT: ret
+; NODOT-LABEL: partial_reduce_select:
+; NODOT: # %bb.0: # %entry
+; NODOT-NEXT: vsetvli a0, zero, e16, m2, ta, ma
+; NODOT-NEXT: vsext.vf2 v12, v8
+; NODOT-NEXT: vsext.vf2 v14, v9
+; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT: vmv.v.i v8, 0
+; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, mu
+; NODOT-NEXT: vwmul.vv v8, v12, v14, v0.t
+; NODOT-NEXT: vsetvli a0, zero, e32, m1, ta, ma
+; NODOT-NEXT: vadd.vv v8, v11, v8
+; NODOT-NEXT: vadd.vv v9, v9, v10
+; NODOT-NEXT: vadd.vv v8, v9, v8
+; NODOT-NEXT: ret
+;
+; DOT-LABEL: partial_reduce_select:
+; DOT: # %bb.0: # %entry
+; DOT-NEXT: vsetvli a0, zero, e8, m1, ta, ma
+; DOT-NEXT: vmv.v.i v10, 0
+; DOT-NEXT: vmerge.vvm v10, v10, v9, v0
+; DOT-NEXT: vsetvli a0, zero, e32, m1, ta, ma
+; DOT-NEXT: vmv.v.i v9, 0
+; DOT-NEXT: vqdot.vv v9, v8, v10
+; DOT-NEXT: vmv.v.v v8, v9
+; DOT-NEXT: ret
entry:
%a.sext = sext <vscale x 8 x i8> %a to <vscale x 8 x i32>
%b.sext = sext <vscale x 8 x i8> %b to <vscale x 8 x i32>
>From 7506ebc2766b7ff82b6c8902e19d46c81d74b1fb Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Thu, 13 Nov 2025 13:53:22 +0000
Subject: [PATCH 2/3] Add freeze
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 22 ++++++++++---------
1 file changed, 12 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 1cba1d4d5cc22..ff98c0910f4d3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13083,15 +13083,18 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
- // Return 'select(P, Op, splat(0))' if P is nonzero,
- // or 'P' otherwise.
- auto tryPredicate = [&](SDValue P, SDValue Op) {
+ // Sets Op = select(P, Op, splat(0)) if P is nonzero, or Op otherwise.
+ // Set ToFreezeOp = freeze(ToFreezeOp) if the value may be poison, to
+ // keep the same semantics.
+ auto ApplyPredicate = [&](SDValue P, SDValue &Op, SDValue &ToFreezeOp) {
if (!P)
- return Op;
+ return;
+ if (!DAG.isGuaranteedNotToBePoison(ToFreezeOp))
+ ToFreezeOp = DAG.getFreeze(ToFreezeOp);
EVT OpVT = Op.getValueType();
SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
: DAG.getConstant(0, DL, OpVT);
- return DAG.getSelect(DL, OpVT, P, Op, Zero);
+ Op = DAG.getSelect(DL, OpVT, P, Op, Zero);
};
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
@@ -13116,10 +13119,9 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
- SDValue Constant =
- tryPredicate(Pred, DAG.getConstant(CTrunc, DL, LHSExtOpVT));
- return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp,
- Constant);
+ SDValue C = DAG.getConstant(CTrunc, DL, LHSExtOpVT);
+ ApplyPredicate(Pred, C, LHSExtOp);
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, C);
}
unsigned RHSOpcode = RHS->getOpcode();
@@ -13160,7 +13162,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
- RHSExtOp = tryPredicate(Pred, RHSExtOp);
+ ApplyPredicate(Pred, RHSExtOp, LHSExtOp);
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}
>From ccee504ed7b27011c501079079a697c987819efc Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Mon, 17 Nov 2025 08:49:06 +0000
Subject: [PATCH 3/3] Address suggestions
---
llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 4 ++
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 41 ++++++++-----------
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 5 +++
3 files changed, 26 insertions(+), 24 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index cd466dceb900f..cfc8a4243e894 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1968,6 +1968,10 @@ LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
/// Build vector implicit truncation is allowed.
LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
+/// Return true if the value is a constant (+/-)0.0 floating-point value or a
+/// splatted vector thereof. Undef splat elements are tolerated only when
+/// \p AllowUndefs is true.
+LLVM_ABI bool isZeroOrZeroSplatFP(SDValue N, bool AllowUndefs = false);
+
/// Return true if \p V is either a integer or FP constant.
inline bool isIntOrFPConstant(SDValue V) {
return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index ff98c0910f4d3..59587329493fa 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13039,11 +13039,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
// Handle predication by moving the SELECT into the operand of the MUL.
SDValue Pred;
- if (Opc == ISD::VSELECT) {
- APInt C;
- if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
- !C.isZero())
- return SDValue();
+ if (Opc == ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) ||
+ isZeroOrZeroSplatFP(Op1->getOperand(2)))) {
Pred = Op1->getOperand(0);
Op1 = Op1->getOperand(1);
Opc = Op1->getOpcode();
@@ -13083,18 +13080,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue LHSExtOp = LHS->getOperand(0);
EVT LHSExtOpVT = LHSExtOp.getValueType();
- // Sets Op = select(P, Op, splat(0)) if P is nonzero, or Op otherwise.
- // Set ToFreezeOp = freeze(ToFreezeOp) if the value may be poison, to
- // keep the same semantics.
- auto ApplyPredicate = [&](SDValue P, SDValue &Op, SDValue &ToFreezeOp) {
- if (!P)
- return;
- if (!DAG.isGuaranteedNotToBePoison(ToFreezeOp))
- ToFreezeOp = DAG.getFreeze(ToFreezeOp);
- EVT OpVT = Op.getValueType();
- SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
- : DAG.getConstant(0, DL, OpVT);
- Op = DAG.getSelect(DL, OpVT, P, Op, Zero);
+ // If Pred is set, replace Op with select(Pred, Op, splat(0)) and freeze
+ // OtherOp, so lanes where Pred is false multiply to 0 instead of poison.
+ // NOTE(review): for FP, an Inf/NaN OtherOp lane times 0.0 is NaN, not 0.
+ auto ApplyPredicate = [&](SDValue &Op, SDValue &OtherOp) {
+ if (Pred) {
+ EVT OpVT = Op.getValueType();
+ SDValue Zero = OpVT.isFloatingPoint() ? DAG.getConstantFP(0.0, DL, OpVT)
+ : DAG.getConstant(0, DL, OpVT);
+ Op = DAG.getSelect(DL, OpVT, Pred, Op, Zero);
+ OtherOp = DAG.getFreeze(OtherOp);
+ }
};
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
@@ -13120,7 +13116,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
return SDValue();
SDValue C = DAG.getConstant(CTrunc, DL, LHSExtOpVT);
- ApplyPredicate(Pred, C, LHSExtOp);
+ ApplyPredicate(C, LHSExtOp);
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, LHSExtOp, C);
}
@@ -13162,7 +13158,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
return SDValue();
- ApplyPredicate(Pred, RHSExtOp, LHSExtOp);
+ ApplyPredicate(RHSExtOp, LHSExtOp);
return DAG.getNode(NewOpc, DL, N->getValueType(0), Acc, LHSExtOp, RHSExtOp);
}
@@ -13184,11 +13180,8 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDValue Pred;
unsigned Op1Opcode = Op1.getOpcode();
- if (Op1Opcode == ISD::VSELECT) {
- APInt C;
- if (!ISD::isConstantSplatVector(Op1->getOperand(2).getNode(), C) ||
- !C.isZero())
- return SDValue();
+ if (Op1Opcode == ISD::VSELECT && (isZeroOrZeroSplat(Op1->getOperand(2)) ||
+ isZeroOrZeroSplatFP(Op1->getOperand(2)))) {
Pred = Op1->getOperand(0);
Op1 = Op1->getOperand(1);
Op1Opcode = Op1->getOpcode();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index c2b4c19846316..16fdef06d6679 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12971,6 +12971,11 @@ bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
return C && C->isZero();
}
+bool llvm::isZeroOrZeroSplatFP(SDValue N, bool AllowUndefs) {
+  const ConstantFPSDNode *Splat = isConstOrConstSplatFP(N, AllowUndefs);
+  return Splat != nullptr && Splat->isZero(); // Accepts +0.0 and -0.0.
+}
+
HandleSDNode::~HandleSDNode() {
DropOperands();
}
More information about the llvm-commits
mailing list