[llvm] [CodeGen] Implement widening for partial.reduce.add (PR #161834)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 3 05:22:58 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Sander de Smalen (sdesmalen-arm)

<details>
<summary>Changes</summary>

Widening of the accumulator/result is done by padding the accumulator with zero elements, performing the partial reduction on the wide type, and then reducing the wide result vector (repeatedly extracting the lo/hi halves and adding them) until it fits in the narrow part of the result vector.

Widening of the input vector is done by padding it with zero elements.

---
Full diff: https://github.com/llvm/llvm-project/pull/161834.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+2) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+51) 
- (added) llvm/test/CodeGen/AArch64/partial-reduce-widen.ll (+25) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 586c3411791f9..c4d69aa48434a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -1117,6 +1117,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecRes_Unary(SDNode *N);
   SDValue WidenVecRes_InregOp(SDNode *N);
   SDValue WidenVecRes_UnaryOpWithTwoResults(SDNode *N, unsigned ResNo);
+  SDValue WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N);
   void ReplaceOtherWidenResults(SDNode *N, SDNode *WidenNode,
                                 unsigned WidenResNo);
 
@@ -1152,6 +1153,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecOp_VP_REDUCE(SDNode *N);
   SDValue WidenVecOp_ExpOp(SDNode *N);
   SDValue WidenVecOp_VP_CttzElements(SDNode *N);
+  SDValue WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N);
 
   /// Helper function to generate a set of operations to perform
   /// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 87d5453cd98cf..4b409eb5f4c6c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -5136,6 +5136,10 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
     if (!unrollExpandedOp())
       Res = WidenVecRes_UnaryOpWithTwoResults(N, ResNo);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = WidenVecRes_PARTIAL_REDUCE_MLA(N);
+    break;
   }
   }
 
@@ -6995,6 +6999,34 @@ SDValue DAGTypeLegalizer::WidenVecRes_STRICT_FSETCC(SDNode *N) {
   return DAG.getBuildVector(WidenVT, dl, Scalars);
 }
 
+// Widening the result of a partial reduction is implemented by
+// accumulating into a wider (zero-padded) vector, then incrementally
+// reducing that (extract half vector and add) until it fits
+// the original type.
+SDValue DAGTypeLegalizer::WidenVecRes_PARTIAL_REDUCE_MLA(SDNode *N) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  EVT WideAccVT = TLI.getTypeToTransformTo(*DAG.getContext(),
+                                           N->getOperand(0).getValueType());
+  SDValue Zero = DAG.getConstant(0, DL, WideAccVT);
+  SDValue MulOp1 = N->getOperand(1);
+  SDValue MulOp2 = N->getOperand(2);
+  SDValue Acc = DAG.getInsertSubvector(DL, Zero, N->getOperand(0), 0);
+  SDValue WidenedRes =
+      DAG.getNode(N->getOpcode(), DL, WideAccVT, Acc, MulOp1, MulOp2);
+  while (ElementCount::isKnownLT(
+      VT.getVectorElementCount(),
+      WidenedRes.getValueType().getVectorElementCount())) {
+    EVT HalfVT =
+        WidenedRes.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
+    SDValue Lo = DAG.getExtractSubvector(DL, HalfVT, WidenedRes, 0);
+    SDValue Hi = DAG.getExtractSubvector(DL, HalfVT, WidenedRes,
+                                         HalfVT.getVectorMinNumElements());
+    WidenedRes = DAG.getNode(ISD::ADD, DL, HalfVT, Lo, Hi);
+  }
+  return DAG.getInsertSubvector(DL, Zero, WidenedRes, 0);
+}
+
 //===----------------------------------------------------------------------===//
 // Widen Vector Operand
 //===----------------------------------------------------------------------===//
@@ -7127,6 +7159,10 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VP_REDUCE_FMINIMUM:
     Res = WidenVecOp_VP_REDUCE(N);
     break;
+  case ISD::PARTIAL_REDUCE_UMLA:
+  case ISD::PARTIAL_REDUCE_SMLA:
+    Res = WidenVecOp_PARTIAL_REDUCE_MLA(N);
+    break;
   case ISD::VP_CTTZ_ELTS:
   case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
     Res = WidenVecOp_VP_CttzElements(N);
@@ -8026,6 +8062,21 @@ SDValue DAGTypeLegalizer::WidenVecOp_VP_CttzElements(SDNode *N) {
                      {Source, Mask, N->getOperand(2)}, N->getFlags());
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_PARTIAL_REDUCE_MLA(SDNode *N) {
+  // Widening of multiplicand operands only. The result and accumulator
+  // should already be legal types.
+  SDLoc DL(N);
+  EVT WideOpVT = TLI.getTypeToTransformTo(*DAG.getContext(),
+                                          N->getOperand(1).getValueType());
+  SDValue Acc = N->getOperand(0);
+  SDValue WidenedOp1 = DAG.getInsertSubvector(
+      DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(1), 0);
+  SDValue WidenedOp2 = DAG.getInsertSubvector(
+      DL, DAG.getConstant(0, DL, WideOpVT), N->getOperand(2), 0);
+  return DAG.getNode(N->getOpcode(), DL, Acc.getValueType(), Acc, WidenedOp1,
+                     WidenedOp2);
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Widening Utilities
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
new file mode 100644
index 0000000000000..a6b215b610fca
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/partial-reduce-widen.ll
@@ -0,0 +1,25 @@
+; RUN: llc -mattr=+sve,+dotprod < %s | FileCheck %s
+
+define void @partial_reduce_widen_v1i32_acc_v16i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+  %acc = load <1 x i32>, ptr %accptr
+  %vec = load <16 x i32>, ptr %vecptr
+  %partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <16 x i32> %vec)
+  store <1 x i32> %partial.reduce, ptr %resptr
+  ret void
+}
+
+define void @partial_reduce_widen_v3i32_acc_v12i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+  %acc = load <3 x i32>, ptr %accptr
+  %vec = load <12 x i32>, ptr %vecptr
+  %partial.reduce = call <3 x i32> @llvm.vector.partial.reduce.add(<3 x i32> %acc, <12 x i32> %vec)
+  store <3 x i32> %partial.reduce, ptr %resptr
+  ret void
+}
+
+define void @partial_reduce_widen_v4i32_acc_v20i32_vec(ptr %accptr, ptr %resptr, ptr %vecptr) {
+  %acc = load <1 x i32>, ptr %accptr
+  %vec = load <20 x i32>, ptr %vecptr
+  %partial.reduce = call <1 x i32> @llvm.vector.partial.reduce.add(<1 x i32> %acc, <20 x i32> %vec)
+  store <1 x i32> %partial.reduce, ptr %resptr
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/161834


More information about the llvm-commits mailing list