[llvm] [CodeGen] Implement widening for partial.reduce.add (PR #161834)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 3 06:06:11 PDT 2025
https://github.com/sdesmalen-arm updated https://github.com/llvm/llvm-project/pull/161834
>From d74ba76a8594eb85e2ed1674267691fcbbae598b Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 3 Oct 2025 11:36:33 +0100
Subject: [PATCH 1/2] [CodeGen] Implement widening for partial.reduce.add
Widening of accumulator/result is done by padding the accumulator with zero
elements, performing the partial reduction and then partially reducing the
wide vector result (using extract lo/hi + add) into the narrow part of the
result vector.
Widening of the input vector is done by padding it with zero elements.
---
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h | 2 +
.../SelectionDAG/LegalizeVectorTypes.cpp | 51 +++++++++++++++++++
.../CodeGen/AArch64/partial-reduce-widen.ll | 25 +++++++++
3 files changed, 78 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 586c3411791f9..c4d69aa48434a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1117,6 +1117,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecRes_Unary(SDNode *N);
SDValue WidenVecRes_InregOp(SDNode *N);
SDValue WidenVecRes_UnaryOpWithTwoResults(SDNode *N, unsigned ResNo);
+ SDValue WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
void ReplaceOtherWidenResults(SDNode *N, SDNode *WidenNode,
unsigned WidenResNo);
@@ -1152,6 +1153,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecOp_VP_REDUCE(SDNode *N);
SDValue WidenVecOp_ExpOp(SDNode *N);
SDValue WidenVecOp_VP_CttzElements(SDNode *N);
+ SDValue WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N);
/// Helper function to generate a set of operations to perform
/// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 87d5453cd98cf..4b409eb5f4c6c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5136,6 +5136,10 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
if (!unrollExpandedOp())
Res = WidenVecRes_UnaryOpWithTwoResults(N, ResNo);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = WidenVecRes_PARTIAL_REDUCE_MLA(N);
+ break;
}
}
@@ -6995,6 +6999,34 @@ SDValue DAGTypeLegalizer::WidenVecRes_STRICT_FSETCC(SDNode *N) {
return DAG.getBuildVector(WidenVT, dl, Scalars);
}
+// Widening the result of a partial reduction is implemented by
+// accumulating into a wider (zero-padded) vector, then incrementally
+// reducing that (extract half vector and add) until it fits
+// the original type.
+SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ EVT WideAccVT = TLI.getTypeToTransformTo(*DAG.getContext(),
+ N->getOperand(0).getValueType());
+ SDValue Zero = DAG.getConstant(0, DL, WideAccVT);
+ SDValue MulOp1 = N->getOperand(1);
+ SDValue MulOp2 = N->getOperand(2);
+ SDValue Acc = DAG.getInsertSubvector(DL, Zero, N->getOperand(0), 0);
+ SDValue WidenedRes =
+ DAG.getNode(N->getOpcode(), DL, WideAccVT, Acc, MulOp1, MulOp2);
+ while (ElementCount::isKnownLT(
+ VT.getVectorElementCount(),
+ WidenedRes.getValueType().getVectorElementCount())) {
+ EVT HalfVT =
+ WidenedRes.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
+ SDValue Lo = DAG.getExtractSubvector(DL, HalfVT, WidenedRes, 0);
+ SDValue Hi = DAG.getExtractSubvector(DL, HalfVT, WidenedRes,
+ HalfVT.getVectorMinNumElements());
+ WidenedRes = DAG.getNode(ISD::ADD, DL, HalfVT, Lo, Hi);
+ }
+ return DAG.getInsertSubvector(DL, Zero, WidenedRes, 0);
+}
+
//===----------------------------------------------------------------------===//
// Widen Vector Operand
//===----------------------------------------------------------------------===//
@@ -7127,6 +7159,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_REDUCE_FMINIMUM:
Res = WidenVecOp_VP_REDUCE(N);
break;
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ Res = WidenVecOp_PARTIAL_REDUCE_MLA(N);
+ break;
case ISD::VP_CTTZ_ELTS:
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
Res = WidenVecOp_VP_CttzElements(N);
@@ -8026,6 +8062,21 @@ SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
{Source, Mask, N->getOperand(2)}, N->getFlags());
}
+SDValue DAGTypeLegalizer::WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+  // Widening of multiplicand operands only. The result and accumulator
+ // should already be legal types.
+ SDLoc DL(N);
+ EVT WideOpVT = TLI.getTypeToTransformTo(*DAG.getContext(),
+ N->getOperand(1).getValueType());
+ SDValue Acc = N->getOperand(0);
+ SDValue WidenedOp1 = DAG.getInsertSubvector(
+ DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(1), 0);
+ SDValue WidenedOp2 = DAG.getInsertSubvector(
+ DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(2), 0);
+ return DAG.getNode(N->getOpcode(), DL, Acc.getValueType(), Acc, WidenedOp1,
+ WidenedOp2);
+}
+
//===----------------------------------------------------------------------===//
// Vector Widening Utilities
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
new file mode 100644
index 0000000000000..a6b215b610fca
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
@@ -0,0 +1,25 @@
+; RUN: llc -mattr=+sve,+dotprod < %s | FileCheck %s
+
+define void @partial_reduce_widen_v1i32_acc_v16i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+ %acc = load <1 x i32>, ptr %accptr
+ %vec = load <16 x i32>, ptr %vecptr
+ %partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <16 x i32> %vec)
+ store <1 x i32> %partial.reduce, ptr %resptr
+ ret void
+}
+
+define void @partial_reduce_widen_v3i32_acc_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+ %acc = load <3 x i32>, ptr %accptr
+ %vec = load <12 x i32>, ptr %vecptr
+ %partial.reduce = call <3 x i32> @llvm.vector.partial.reduce.add(<3 x i32> %acc, <12 x i32> %vec)
+ store <3 x i32> %partial.reduce, ptr %resptr
+ ret void
+}
+
+define void @partial_reduce_widen_v4i32_acc_v20i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+ %acc = load <1 x i32>, ptr %accptr
+ %vec = load <20 x i32>, ptr %vecptr
+ %partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <20 x i32> %vec)
+ store <1 x i32> %partial.reduce, ptr %resptr
+ ret void
+}
>From 18f192e6d29327cb99d13555f36c9c9c5166abfb Mon Sep 17 00:00:00 2001
From: Sander de Smalen <sander.desmalen at arm.com>
Date: Fri, 3 Oct 2025 14:04:45 +0100
Subject: [PATCH 2/2] Add missing CHECK lines, also use poison instead of zero,
for widened result lanes
---
.../SelectionDAG/LegalizeVectorTypes.cpp | 2 +-
.../CodeGen/AArch64/partial-reduce-widen.ll | 75 ++++++++++++++++++-
2 files changed, 75 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 4b409eb5f4c6c..ba3cbf5d4083b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -7024,7 +7024,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
HalfVT.getVectorMinNumElements());
WidenedRes = DAG.getNode(ISD::ADD, DL, HalfVT, Lo, Hi);
}
- return DAG.getInsertSubvector(DL, Zero, WidenedRes, 0);
+ return DAG.getInsertSubvector(DL, DAG.getPOISON(WideAccVT), WidenedRes, 0);
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
index a6b215b610fca..61cb149727c06 100644
--- a/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
@@ -1,6 +1,30 @@
-; RUN: llc -mattr=+sve,+dotprod < %s | FileCheck %s
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s | FileCheck %s
+
+target triple = "aarch64"
define void @partial_reduce_widen_v1i32_acc_v16i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+; CHECK-LABEL: partial_reduce_widen_v1i32_acc_v16i32_vec:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ldp q1, q0, [x2]
+; CHECK-NEXT: ldr s2, [x0]
+; CHECK-NEXT: ldp q5, q6, [x2, #32]
+; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
+; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8
+; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
+; CHECK-NEXT: ext v2.16b, v5.16b, v5.16b, #8
+; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: add v1.2s, v4.2s, v3.2s
+; CHECK-NEXT: ext v3.16b, v6.16b, v6.16b, #8
+; CHECK-NEXT: add v0.2s, v0.2s, v5.2s
+; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
+; CHECK-NEXT: add v0.2s, v0.2s, v6.2s
+; CHECK-NEXT: add v1.2s, v3.2s, v1.2s
+; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: dup v1.2s, v0.s[1]
+; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: str s0, [x1]
+; CHECK-NEXT: ret
%acc = load <1 x i32>, ptr %accptr
%vec = load <16 x i32>, ptr %vecptr
%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <16 x i32> %vec)
@@ -9,6 +33,24 @@ define void @partial_reduce_widen_v1i32_acc_v16i32_vec(ptr %accptr, ptr %resptr,
}
define void @partial_reduce_widen_v3i32_acc_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+; CHECK-LABEL: partial_reduce_widen_v3i32_acc_v12i32_vec:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sub sp, sp, #128
+; CHECK-NEXT: .cfi_def_cfa_offset 128
+; CHECK-NEXT: ldp q1, q0, [x2]
+; CHECK-NEXT: ldr q2, [x0]
+; CHECK-NEXT: mov v2.s[3], wzr
+; CHECK-NEXT: add v0.4s, v1.4s, v0.4s
+; CHECK-NEXT: ldr q1, [x2, #32]
+; CHECK-NEXT: add v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: mov s1, v0.s[2]
+; CHECK-NEXT: str d0, [x1]
+; CHECK-NEXT: str s1, [x1, #8]
+; CHECK-NEXT: add sp, sp, #128
+; CHECK-NEXT: ret
%acc = load <3 x i32>, ptr %accptr
%vec = load <12 x i32>, ptr %vecptr
%partial.reduce = call <3 x i32> @llvm.vector.partial.reduce.add(<3 x i32> %acc, <12 x i32> %vec)
@@ -17,6 +59,37 @@ define void @partial_reduce_widen_v3i32_acc_v12i32_vec(ptr %accptr, ptr %resptr,
}
define void @partial_reduce_widen_v4i32_acc_v20i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+; CHECK-LABEL: partial_reduce_widen_v4i32_acc_v20i32_vec:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sub sp, sp, #272
+; CHECK-NEXT: str x29, [sp, #256] // 8-byte Folded Spill
+; CHECK-NEXT: .cfi_def_cfa_offset 272
+; CHECK-NEXT: .cfi_offset w29, -16
+; CHECK-NEXT: ldp q1, q0, [x2]
+; CHECK-NEXT: ldr s2, [x0]
+; CHECK-NEXT: ldp q5, q6, [x2, #32]
+; CHECK-NEXT: ldr x29, [sp, #256] // 8-byte Folded Reload
+; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
+; CHECK-NEXT: ext v4.16b, v0.16b, v0.16b, #8
+; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
+; CHECK-NEXT: ext v2.16b, v5.16b, v5.16b, #8
+; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: add v1.2s, v4.2s, v3.2s
+; CHECK-NEXT: ext v3.16b, v6.16b, v6.16b, #8
+; CHECK-NEXT: ldr q4, [x2, #64]
+; CHECK-NEXT: add v0.2s, v0.2s, v5.2s
+; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
+; CHECK-NEXT: ext v2.16b, v4.16b, v4.16b, #8
+; CHECK-NEXT: add v0.2s, v0.2s, v6.2s
+; CHECK-NEXT: add v1.2s, v3.2s, v1.2s
+; CHECK-NEXT: add v0.2s, v0.2s, v4.2s
+; CHECK-NEXT: add v1.2s, v2.2s, v1.2s
+; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: dup v1.2s, v0.s[1]
+; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: str s0, [x1]
+; CHECK-NEXT: add sp, sp, #272
+; CHECK-NEXT: ret
%acc = load <1 x i32>, ptr %accptr
%vec = load <20 x i32>, ptr %vecptr
%partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <20 x i32> %vec)
More information about the llvm-commits
mailing list