[llvm] Add `llvm.vector.partial.reduce.fadd` intrinsic (PR #159776)

Damian Heaton via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 23 02:17:10 PDT 2025


https://github.com/dheaton-arm updated https://github.com/llvm/llvm-project/pull/159776

>From e60db23ef3e7ca8df80ade4ace4772b18dbd85d3 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Fri, 19 Sep 2025 10:51:26 +0000
Subject: [PATCH 1/4] Add `llvm.vector.partial.reduce.fadd` intrinsic

With this intrinsic and the supporting SelectionDAG nodes, we can make better use of instructions such as AArch64's `FDOT`.
---
 llvm/docs/LangRef.rst                         | 42 ++++++++++++
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  3 +-
 llvm/include/llvm/CodeGen/TargetLowering.h    |  4 +-
 llvm/include/llvm/IR/Intrinsics.td            |  4 ++
 .../include/llvm/Target/TargetSelectionDAG.td |  2 +
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 22 +++++--
 .../SelectionDAG/LegalizeVectorOps.cpp        |  2 +
 .../SelectionDAG/LegalizeVectorTypes.cpp      |  2 +
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  3 +-
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 13 ++++
 .../SelectionDAG/SelectionDAGDumper.cpp       |  2 +
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 21 +++---
 .../Target/AArch64/AArch64ISelLowering.cpp    |  9 ++-
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  3 +
 llvm/test/CodeGen/AArch64/sve2p1-fdot.ll      | 66 +++++++++++++++++++
 15 files changed, 179 insertions(+), 19 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve2p1-fdot.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index e6713c827d6ab..ea9bf43591f41 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20614,6 +20614,48 @@ performance, and an out-of-loop phase to calculate the final scalar result.
 By avoiding the introduction of new ordering constraints, these intrinsics
 enhance the ability to leverage a target's accumulation instructions.
 
+'``llvm.vector.partial.reduce.fadd.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <4 x f32> @llvm.vector.partial.reduce.fadd.v4f32.v4f32.v8f32(<4 x f32> %a, <8 x f32> %b)
+      declare <vscale x 4 x f32> @llvm.vector.partial.reduce.add.nxv4f32.nxv4f32.nxv8f32(<vscale x 4 x f32> %a, <vscale x 8 x f32> %b)
+
+Overview:
+"""""""""
+
+The '``llvm.vector.partial.reduce.fadd.*``' intrinsics use floating-point
+addition to reduce the concatenation of the two vector arguments down to the
+number of elements of the result vector type.
+
+Arguments:
+""""""""""
+
+The first argument is a floating-point vector with the same type as the result.
+
+The second argument is a vector with the same element type as the result, whose
+length is a known integer multiple of the result vector's length.
+
+Semantics:
+""""""""""
+
+Other than the reduction operator (e.g. add) the way in which the concatenated
+arguments is reduced is entirely unspecified. By their nature these intrinsics
+are not expected to be useful in isolation but instead implement the first phase
+of an overall reduction operation.
+
+The typical use case is loop vectorization where reductions are split into an
+in-loop phase, where maintaining an unordered vector result is important for
+performance, and an out-of-loop phase to calculate the final scalar result.
+
+By avoiding the introduction of new ordering constraints, these intrinsics
+enhance the ability to leverage a target's accumulation instructions.
+
 '``llvm.experimental.vector.histogram.*``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index c76c83d84b3c7..83ee6ff677e3d 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1510,6 +1510,7 @@ enum NodeType {
   PARTIAL_REDUCE_SMLA,  // sext, sext
   PARTIAL_REDUCE_UMLA,  // zext, zext
   PARTIAL_REDUCE_SUMLA, // sext, zext
+  PARTIAL_REDUCE_FMLA,  // fpext, fpext
 
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
@@ -1761,7 +1762,7 @@ LLVM_ABI CondCode getSetCCInverse(CondCode Operation, EVT Type);
 
 inline bool isExtOpcode(unsigned Opcode) {
   return Opcode == ISD::ANY_EXTEND || Opcode == ISD::ZERO_EXTEND ||
-         Opcode == ISD::SIGN_EXTEND;
+         Opcode == ISD::SIGN_EXTEND || Opcode == ISD::FP_EXTEND;
 }
 
 inline bool isExtVecInRegOpcode(unsigned Opcode) {
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 4c2d991308d30..09b699ff9ae49 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1672,7 +1672,7 @@ class LLVM_ABI TargetLoweringBase {
   LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
                                            EVT InputVT) const {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
                                     InputVT.getSimpleVT().SimpleTy};
     auto It = PartialReduceMLAActions.find(Key);
@@ -2774,7 +2774,7 @@ class LLVM_ABI TargetLoweringBase {
   void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
                                  LegalizeAction Action) {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     assert(AccVT.isValid() && InputVT.isValid() &&
            "setPartialReduceMLAAction types aren't valid");
     PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 585371a6a4423..1ecfe284e05fa 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2801,6 +2801,10 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
                                                           [llvm_anyvector_ty, llvm_anyvector_ty],
                                                           [IntrNoMem]>;
 
+def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                                                        [llvm_anyfloat_ty, llvm_anyfloat_ty],
+                                                                        [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 5e57dcaa303f3..1f0800885937e 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -522,6 +522,8 @@ def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
                                  SDTPartialReduceMLA>;
 def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
                                  SDTPartialReduceMLA>;
+def partial_reduce_fmla : SDNode<"ISD::PARTIAL_REDUCE_FMLA",
+                                 SDTPartialReduceMLA>;
 
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0c773e7dcb5de..94114a992fc0c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2040,6 +2040,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
                                 return visitPARTIAL_REDUCE_MLA(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
@@ -12979,8 +12980,11 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue Op2 = N->getOperand(2);
 
   APInt C;
-  if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+  if (!(Op1->getOpcode() == ISD::MUL &&
+        ISD::isConstantSplatVector(Op2.getNode(), C) && C.isOne()) &&
+      !(Op1->getOpcode() == ISD::FMUL &&
+        ISD::isConstantSplatVector(Op2.getNode(), C) &&
+        C == APFloat(1.0f).bitcastToAPInt().trunc(C.getBitWidth())))
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
@@ -13035,6 +13039,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
     NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
     std::swap(LHSExtOp, RHSExtOp);
+  } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
+    NewOpc = ISD::PARTIAL_REDUCE_FMLA;
   } else
     return SDValue();
   // For a 2-stage extend the signedness of both of the extends must match
@@ -13070,22 +13076,26 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
 
   APInt ConstantOne;
   if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !(ConstantOne.isOne() ||
+        ConstantOne ==
+            APFloat(1.0f).bitcastToAPInt().trunc(ConstantOne.getBitWidth())))
     return SDValue();
 
   unsigned Op1Opcode = Op1.getOpcode();
   if (!ISD::isExtOpcode(Op1Opcode))
     return SDValue();
 
-  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+  bool Op1IsSigned = Op1Opcode != ISD::ZERO_EXTEND;
   bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
   if (Op1IsSigned != NodeIsSigned &&
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  unsigned NewOpcode =
-      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+                           ? ISD::PARTIAL_REDUCE_FMLA
+                       : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
+                                     : ISD::PARTIAL_REDUCE_UMLA;
 
   SDValue UnextOp1 = Op1.getOperand(0);
   EVT UnextOp1VT = UnextOp1.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8e423c4f83b38..94751be5b7986 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Action =
         TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
                                       Node->getOperand(1).getValueType());
@@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
     return;
   case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index ff7cd665446cc..e6f19499b8f41 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1459,6 +1459,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
     break;
   case ISD::GET_ACTIVE_LANE_MASK:
@@ -3674,6 +3675,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
     break;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2dab927d2648c..e7f8e3487f0c4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8375,7 +8375,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   }
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
-  case ISD::PARTIAL_REDUCE_SUMLA: {
+  case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA: {
     [[maybe_unused]] EVT AccVT = N1.getValueType();
     [[maybe_unused]] EVT Input1VT = N2.getValueType();
     [[maybe_unused]] EVT Input2VT = N3.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index b5201a311c591..544b4abd6b45a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8114,6 +8114,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
                          Input, DAG.getConstant(1, sdl, Input.getValueType())));
     return;
   }
+  case Intrinsic::vector_partial_reduce_fadd: {
+    if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
+    }
+    SDValue Acc = getValue(I.getOperand(0));
+    SDValue Input = getValue(I.getOperand(1));
+    setValue(&I,
+             DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
+                         Input,
+                         DAG.getConstantFP(1.0f, sdl, Input.getValueType())));
+    return;
+  }
   case Intrinsic::experimental_cttz_elts: {
     auto DL = getCurSDLoc();
     SDValue Op = getValue(I.getOperand(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 4b2a00c2e2cfa..cf5c269c20761 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -587,6 +587,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
     return "partial_reduce_smla";
   case ISD::PARTIAL_REDUCE_SUMLA:
     return "partial_reduce_sumla";
+  case ISD::PARTIAL_REDUCE_FMLA:
+    return "partial_reduce_fmla";
   case ISD::LOOP_DEPENDENCE_WAR_MASK:
     return "loop_dep_war";
   case ISD::LOOP_DEPENDENCE_RAW_MASK:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index e73f82f3786b8..943f292d67fc8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12061,12 +12061,14 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
       EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
                        MulOpVT.getVectorElementCount());
 
-  unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
-                      ? ISD::ZERO_EXTEND
-                      : ISD::SIGN_EXTEND;
-  unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
-                      ? ISD::SIGN_EXTEND
-                      : ISD::ZERO_EXTEND;
+  unsigned ExtOpcLHS =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA   ? ISD::FP_EXTEND
+      : N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA ? ISD::ZERO_EXTEND
+                                                   : ISD::SIGN_EXTEND;
+  unsigned ExtOpcRHS =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA   ? ISD::FP_EXTEND
+      : N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ? ISD::SIGN_EXTEND
+                                                   : ISD::ZERO_EXTEND;
 
   if (ExtMulOpVT != MulOpVT) {
     MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
@@ -12075,7 +12077,7 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   SDValue Input = MulLHS;
   APInt ConstantOne;
   if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+      !(ConstantOne.isOne() || ConstantOne == APFloat(1.0f).bitcastToAPInt()))
     Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
 
   unsigned Stride = AccVT.getVectorMinNumElements();
@@ -12086,10 +12088,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   for (unsigned I = 0; I < ScaleFactor; I++)
     Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride));
 
+  unsigned FlatNode =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FADD : ISD::ADD;
+
   // Flatten the subvector tree
   while (Subvectors.size() > 1) {
     Subvectors.push_back(
-        DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
+        DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]}));
     Subvectors.pop_front();
     Subvectors.pop_front();
   }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cd7f0e719ad0c..9f14f8318cb18 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1919,6 +1919,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     }
   }
 
+  // Handle floating-point partial reduction
+  if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
+    static const unsigned FMLAOps[] = {ISD::PARTIAL_REDUCE_FMLA};
+    setPartialReduceMLAAction(FMLAOps, MVT::nxv4f32, MVT::nxv8f16, Legal);
+  }
+
   // Handle non-aliasing elements mask
   if (Subtarget->hasSVE2() ||
       (Subtarget->hasSME() && Subtarget->isStreaming())) {
@@ -2184,7 +2190,8 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const IntrinsicInst *I) const {
-  assert(I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add &&
+  assert((I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add ||
+          I->getIntrinsicID() == Intrinsic::vector_partial_reduce_fadd) &&
          "Unexpected intrinsic!");
   return true;
 }
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 7fe4f7acdbd49..8ef69ad13abc5 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4228,6 +4228,9 @@ defm FCLAMP_ZZZ : sve_fp_clamp<"fclamp", AArch64fclamp>;
 defm FDOT_ZZZ_S  : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fdot_x2>;
 defm FDOT_ZZZI_S : sve_float_dot_indexed<0b0, 0b00, ZPR16, ZPR3b16, "fdot", nxv8f16, int_aarch64_sve_fdot_lane_x2>;
 
+def : Pat<(nxv4f32 (partial_reduce_fmla nxv4f32:$Acc, nxv8f16:$LHS, nxv8f16:$RHS)),
+          (FDOT_ZZZ_S $Acc, $LHS, $RHS)>;
+
 defm BFMLSLB_ZZZ_S : sve2_fp_mla_long<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb>;
 defm BFMLSLT_ZZZ_S : sve2_fp_mla_long<0b111, "bfmlslt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslt>;
 defm BFMLSLB_ZZZI_S : sve2_fp_mla_long_by_indexed_elem<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb_lane>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
new file mode 100644
index 0000000000000..5bb1fae43392f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
+
+define <vscale x 4 x float> @fdot_wide_vl128(<vscale x 4 x float> %acc, <vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fdot_wide_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+  %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+  %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
+  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
+  ret <vscale x 4 x float> %partial.reduce
+}
+
+define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
+; CHECK-LABEL: fdot_wide_vl256:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x1]
+; CHECK-NEXT:    ld1h { z1.s }, p0/z, [x2]
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x1, #1, mul vl]
+; CHECK-NEXT:    ld1h { z3.s }, p0/z, [x2, #1, mul vl]
+; CHECK-NEXT:    fcvt z0.s, p0/m, z0.h
+; CHECK-NEXT:    fcvt z1.s, p0/m, z1.h
+; CHECK-NEXT:    fcvt z2.s, p0/m, z2.h
+; CHECK-NEXT:    fcvt z3.s, p0/m, z3.h
+; CHECK-NEXT:    fmul z0.s, z0.s, z1.s
+; CHECK-NEXT:    ldr z1, [x0]
+; CHECK-NEXT:    fmul z2.s, z2.s, z3.s
+; CHECK-NEXT:    fadd z0.s, z1.s, z0.s
+; CHECK-NEXT:    fadd z0.s, z0.s, z2.s
+; CHECK-NEXT:    str z0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %acc = load <8 x float>, ptr %accptr
+  %a = load <16 x half>, ptr %aptr
+  %b = load <16 x half>, ptr %bptr
+  %a.wide = fpext <16 x half> %a to <16 x float>
+  %b.wide = fpext <16 x half> %b to <16 x float>
+  %mult = fmul <16 x float> %a.wide, %b.wide
+  %partial.reduce = call <8 x float> @llvm.vector.partial.reduce.fadd(<8 x float> %acc, <16 x float> %mult)
+  store <8 x float> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define <4 x float> @fixed_fdot_wide(<4 x float> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fixed_fdot_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcvtl v3.4s, v1.4h
+; CHECK-NEXT:    fcvtl v4.4s, v2.4h
+; CHECK-NEXT:    fcvtl2 v1.4s, v1.8h
+; CHECK-NEXT:    fcvtl2 v2.4s, v2.8h
+; CHECK-NEXT:    fmul v3.4s, v3.4s, v4.4s
+; CHECK-NEXT:    fmul v1.4s, v1.4s, v2.4s
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v3.4s
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = fpext <8 x half> %a to <8 x float>
+  %b.wide = fpext <8 x half> %b to <8 x float>
+  %mult = fmul <8 x float> %a.wide, %b.wide
+  %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
+  ret <4 x float> %partial.reduce
+}

>From 8eed437ca219ad40c9971f1ad5409083d3ac5b36 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Fri, 19 Sep 2025 16:00:49 +0000
Subject: [PATCH 2/4] Revert adding `FP_EXTEND` to `isExtOpcode`

---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        | 2 +-
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 83ee6ff677e3d..e1f6aab0040db 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1762,7 +1762,7 @@ LLVM_ABI CondCode getSetCCInverse(CondCode Operation, EVT Type);
 
 inline bool isExtOpcode(unsigned Opcode) {
   return Opcode == ISD::ANY_EXTEND || Opcode == ISD::ZERO_EXTEND ||
-         Opcode == ISD::SIGN_EXTEND || Opcode == ISD::FP_EXTEND;
+         Opcode == ISD::SIGN_EXTEND;
 }
 
 inline bool isExtVecInRegOpcode(unsigned Opcode) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 94114a992fc0c..a5bc3a1779dc1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12990,7 +12990,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue LHS = Op1->getOperand(0);
   SDValue RHS = Op1->getOperand(1);
   unsigned LHSOpcode = LHS->getOpcode();
-  if (!ISD::isExtOpcode(LHSOpcode))
+  if (!ISD::isExtOpcode(LHSOpcode) && LHSOpcode != ISD::FP_EXTEND)
     return SDValue();
 
   SDValue LHSExtOp = LHS->getOperand(0);
@@ -13022,7 +13022,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   }
 
   unsigned RHSOpcode = RHS->getOpcode();
-  if (!ISD::isExtOpcode(RHSOpcode))
+  if (!ISD::isExtOpcode(RHSOpcode) && RHSOpcode != ISD::FP_EXTEND)
     return SDValue();
 
   SDValue RHSExtOp = RHS->getOperand(0);

>From 081fac1fa1acf2b5380d9d7c8fac691017dd9bb3 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Mon, 22 Sep 2025 10:15:46 +0000
Subject: [PATCH 3/4] Address review comments

Corrected LangRef typos, improved constant
comparisons for fadd, and added direct tests.
---
 llvm/docs/LangRef.rst                         |  6 +-
 llvm/include/llvm/IR/Intrinsics.td            |  4 +-
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 20 +++---
 .../CodeGen/SelectionDAG/TargetLowering.cpp   |  8 ++-
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  3 -
 llvm/lib/Target/AArch64/SVEInstrFormats.td    |  1 +
 llvm/test/CodeGen/AArch64/sve2p1-fdot.ll      | 66 +++++++++++++++++++
 7 files changed, 90 insertions(+), 18 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index ea9bf43591f41..6993ad170bce7 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20623,8 +20623,8 @@ This is an overloaded intrinsic.
 
 ::
 
-      declare <4 x f32> @llvm.vector.partial.reduce.fadd.v4f32.v4f32.v8f32(<4 x f32> %a, <8 x f32> %b)
-      declare <vscale x 4 x f32> @llvm.vector.partial.reduce.add.nxv4f32.nxv4f32.nxv8f32(<vscale x 4 x f32> %a, <vscale x 8 x f32> %b)
+      declare <4 x float> @llvm.vector.partial.reduce.fadd.v4f32.v8f32(<4 x float> %a, <8 x float> %b)
+      declare <vscale x 4 x float> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x float> %a, <vscale x 8 x float> %b)
 
 Overview:
 """""""""
@@ -20644,7 +20644,7 @@ of the result's type, while maintaining the same element type.
 Semantics:
 """"""""""
 
-Other than the reduction operator (e.g. add) the way in which the concatenated
+Other than the reduction operator (e.g. fadd) the way in which the concatenated
 arguments is reduced is entirely unspecified. By their nature these intrinsics
 are not expected to be useful in isolation but instead implement the first phase
 of an overall reduction operation.
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 1ecfe284e05fa..28dbaffbcb784 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2802,8 +2802,8 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
                                                           [IntrNoMem]>;
 
 def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
-                                                                        [llvm_anyfloat_ty, llvm_anyfloat_ty],
-                                                                        [IntrNoMem]>;
+                                                           [llvm_anyfloat_ty, llvm_anyfloat_ty],
+                                                           [IntrNoMem]>;
 
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a5bc3a1779dc1..f3ada2a7cdc76 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12980,11 +12980,12 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue Op2 = N->getOperand(2);
 
   APInt C;
+  ConstantFPSDNode *CFP;
   if (!(Op1->getOpcode() == ISD::MUL &&
         ISD::isConstantSplatVector(Op2.getNode(), C) && C.isOne()) &&
       !(Op1->getOpcode() == ISD::FMUL &&
-        ISD::isConstantSplatVector(Op2.getNode(), C) &&
-        C == APFloat(1.0f).bitcastToAPInt().trunc(C.getBitWidth())))
+        (CFP = llvm::isConstOrConstSplatFP(Op2, false)) &&
+        CFP->isExactlyValue(1.0)))
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
@@ -13075,20 +13076,23 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   SDValue Op2 = N->getOperand(2);
 
   APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !(ConstantOne.isOne() ||
-        ConstantOne ==
-            APFloat(1.0f).bitcastToAPInt().trunc(ConstantOne.getBitWidth())))
+  ConstantFPSDNode *C;
+  if (!(N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA &&
+        (C = llvm::isConstOrConstSplatFP(Op2, false)) &&
+        C->isExactlyValue(1.0)) &&
+      !(ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) &&
+        ConstantOne.isOne()))
     return SDValue();
 
   unsigned Op1Opcode = Op1.getOpcode();
   if (!ISD::isExtOpcode(Op1Opcode))
     return SDValue();
 
-  bool Op1IsSigned = Op1Opcode != ISD::ZERO_EXTEND;
+  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
   bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
-  if (Op1IsSigned != NodeIsSigned &&
+  if (N->getOpcode() != ISD::PARTIAL_REDUCE_FMLA &&
+      Op1IsSigned != NodeIsSigned &&
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 943f292d67fc8..9c3bd904b0462 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12076,8 +12076,20 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   }
   SDValue Input = MulLHS;
   APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
-      !(ConstantOne.isOne() || ConstantOne == APFloat(1.0f).bitcastToAPInt()))
-    Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+  // Skip the multiply when the RHS is a splat of one. FP and integer nodes are
+  // checked separately: ISD::isConstantSplatVector also matches FP splats by
+  // raw bits, so an FP splat whose bit pattern is 1 must not count as 1.0.
+  bool IsFP = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA;
+  bool MulRHSIsOne;
+  if (IsFP) {
+    ConstantFPSDNode *C = llvm::isConstOrConstSplatFP(MulRHS, false);
+    MulRHSIsOne = C && C->isExactlyValue(1.0);
+  } else {
+    MulRHSIsOne = ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) &&
+                  ConstantOne.isOne();
+  }
+  if (!MulRHSIsOne)
+    Input = DAG.getNode(IsFP ? ISD::FMUL : ISD::MUL, DL, ExtMulOpVT, MulLHS,
+                        MulRHS);
 
   unsigned Stride = AccVT.getVectorMinNumElements();
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 8ef69ad13abc5..7fe4f7acdbd49 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4228,9 +4228,6 @@ defm FCLAMP_ZZZ : sve_fp_clamp<"fclamp", AArch64fclamp>;
 defm FDOT_ZZZ_S  : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fdot_x2>;
 defm FDOT_ZZZI_S : sve_float_dot_indexed<0b0, 0b00, ZPR16, ZPR3b16, "fdot", nxv8f16, int_aarch64_sve_fdot_lane_x2>;
 
-def : Pat<(nxv4f32 (partial_reduce_fmla nxv4f32:$Acc, nxv8f16:$LHS, nxv8f16:$RHS)),
-          (FDOT_ZZZ_S $Acc, $LHS, $RHS)>;
-
 defm BFMLSLB_ZZZ_S : sve2_fp_mla_long<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb>;
 defm BFMLSLT_ZZZ_S : sve2_fp_mla_long<0b111, "bfmlslt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslt>;
 defm BFMLSLB_ZZZI_S : sve2_fp_mla_long_by_indexed_elem<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb_lane>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 166219de9dfe9..7298cf16d95a3 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -9451,6 +9451,7 @@ multiclass sve_float_dot<bit bf, bit o2, ZPRRegOp dst_ty, ZPRRegOp src_ty,
                          string asm, ValueType InVT, SDPatternOperator op> {
   def NAME : sve_float_dot<bf, o2, dst_ty, src_ty, asm>;
   def : SVE_3_Op_Pat<nxv4f32, op, nxv4f32, InVT, InVT, !cast<Instruction>(NAME)>;
+  def : SVE_3_Op_Pat<nxv4f32, partial_reduce_fmla, nxv4f32, InVT, InVT, !cast<Instruction>(NAME)>;
 }
 
 multiclass sve_fp8_dot<bit bf, ZPRRegOp dstrc, string asm, ValueType vt,
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
index 5bb1fae43392f..69c0b68f23f78 100644
--- a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
@@ -64,3 +64,69 @@ entry:
   %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
   ret <4 x float> %partial.reduce
 }
+
+define <8 x half> @partial_reduce_half(<8 x half> %acc, <16 x half> %a) {
+; CHECK-LABEL: partial_reduce_half:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    fadd v0.8h, v0.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <8 x half> @llvm.vector.partial.reduce.fadd(<8 x half> %acc, <16 x half> %a)
+  ret <8 x half> %partial.reduce
+}
+
+define <4 x float> @partial_reduce_float(<4 x float> %acc, <8 x float> %a) {
+; CHECK-LABEL: partial_reduce_float:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v2.4s
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %a)
+  ret <4 x float> %partial.reduce
+}
+
+define <2 x double> @partial_reduce_double(<2 x double> %acc, <4 x double> %a) {
+; CHECK-LABEL: partial_reduce_double:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    fadd v0.2d, v0.2d, v2.2d
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <2 x double> @llvm.vector.partial.reduce.fadd(<2 x double> %acc, <4 x double> %a)
+  ret <2 x double> %partial.reduce
+}
+
+define <vscale x 8 x half> @partial_reduce_half_vl128(<vscale x 8 x half> %acc, <vscale x 16 x half> %a) {
+; CHECK-LABEL: partial_reduce_half_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd z0.h, z0.h, z1.h
+; CHECK-NEXT:    fadd z0.h, z0.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <vscale x 8 x half> @llvm.vector.partial.reduce.fadd(<vscale x 8 x half> %acc, <vscale x 16 x half> %a)
+  ret <vscale x 8 x half> %partial.reduce
+}
+
+define <vscale x 4 x float> @partial_reduce_float_vl128(<vscale x 4 x float> %acc, <vscale x 8 x float> %a) {
+; CHECK-LABEL: partial_reduce_float_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    fadd z0.s, z0.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %a)
+  ret <vscale x 4 x float> %partial.reduce
+}
+
+define <vscale x 2 x double> @partial_reduce_double_vl128(<vscale x 2 x double> %acc, <vscale x 4 x double> %a) {
+; CHECK-LABEL: partial_reduce_double_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd z0.d, z0.d, z1.d
+; CHECK-NEXT:    fadd z0.d, z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <vscale x 2 x double> @llvm.vector.partial.reduce.fadd(<vscale x 2 x double> %acc, <vscale x 4 x double> %a)
+  ret <vscale x 2 x double> %partial.reduce
+}

>From 319852132602f685aea6228f10418370fd530aa7 Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Tue, 23 Sep 2025 09:13:39 +0000
Subject: [PATCH 4/4] Require reassoc

---
 llvm/lib/IR/Verifier.cpp                 |  6 ++++++
 llvm/test/CodeGen/AArch64/sve2p1-fdot.ll | 18 +++++++++---------
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 17cbfa2458375..0ff4d229c5d9b 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6542,6 +6542,12 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
     }
     break;
   }
+  case Intrinsic::vector_partial_reduce_fadd: {
+    Check(Call.hasAllowReassoc(),
+          "vector_partial_reduce_fadd requires reassociation to be allowed.");
+    // Fall through to perform the same verification checks as for integers.
+    [[fallthrough]];
+  }
   case Intrinsic::vector_partial_reduce_add: {
     VectorType *AccTy = cast<VectorType>(Call.getArgOperand(0)->getType());
     VectorType *VecTy = cast<VectorType>(Call.getArgOperand(1)->getType());
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
index 69c0b68f23f78..aa2184ab6e65e 100644
--- a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
@@ -10,7 +10,7 @@ entry:
   %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
   %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
   %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
-  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
+  %partial.reduce = call reassoc <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
   ret <vscale x 4 x float> %partial.reduce
 }
 
@@ -40,7 +40,7 @@ entry:
   %a.wide = fpext <16 x half> %a to <16 x float>
   %b.wide = fpext <16 x half> %b to <16 x float>
   %mult = fmul <16 x float> %a.wide, %b.wide
-  %partial.reduce = call <8 x float> @llvm.vector.partial.reduce.fadd(<8 x float> %acc, <16 x float> %mult)
+  %partial.reduce = call reassoc <8 x float> @llvm.vector.partial.reduce.fadd(<8 x float> %acc, <16 x float> %mult)
   store <8 x float> %partial.reduce, ptr %accptr
   ret void
 }
@@ -61,7 +61,7 @@ entry:
   %a.wide = fpext <8 x half> %a to <8 x float>
   %b.wide = fpext <8 x half> %b to <8 x float>
   %mult = fmul <8 x float> %a.wide, %b.wide
-  %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
+  %partial.reduce = call reassoc <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
   ret <4 x float> %partial.reduce
 }
 
@@ -72,7 +72,7 @@ define <8 x half> @partial_reduce_half(<8 x half> %acc, <16 x half> %a) {
 ; CHECK-NEXT:    fadd v0.8h, v0.8h, v2.8h
 ; CHECK-NEXT:    ret
 entry:
-  %partial.reduce = call <8 x half> @llvm.vector.partial.reduce.fadd(<8 x half> %acc, <16 x half> %a)
+  %partial.reduce = call reassoc <8 x half> @llvm.vector.partial.reduce.fadd(<8 x half> %acc, <16 x half> %a)
   ret <8 x half> %partial.reduce
 }
 
@@ -83,7 +83,7 @@ define <4 x float> @partial_reduce_float(<4 x float> %acc, <8 x float> %a) {
 ; CHECK-NEXT:    fadd v0.4s, v0.4s, v2.4s
 ; CHECK-NEXT:    ret
 entry:
-  %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %a)
+  %partial.reduce = call reassoc <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %a)
   ret <4 x float> %partial.reduce
 }
 
@@ -94,7 +94,7 @@ define <2 x double> @partial_reduce_double(<2 x double> %acc, <4 x double> %a) {
 ; CHECK-NEXT:    fadd v0.2d, v0.2d, v2.2d
 ; CHECK-NEXT:    ret
 entry:
-  %partial.reduce = call <2 x double> @llvm.vector.partial.reduce.fadd(<2 x double> %acc, <4 x double> %a)
+  %partial.reduce = call reassoc <2 x double> @llvm.vector.partial.reduce.fadd(<2 x double> %acc, <4 x double> %a)
   ret <2 x double> %partial.reduce
 }
 
@@ -105,7 +105,7 @@ define <vscale x 8 x half> @partial_reduce_half_vl128(<vscale x 8 x half> %acc,
 ; CHECK-NEXT:    fadd z0.h, z0.h, z2.h
 ; CHECK-NEXT:    ret
 entry:
-  %partial.reduce = call <vscale x 8 x half> @llvm.vector.partial.reduce.fadd(<vscale x 8 x half> %acc, <vscale x 16 x half> %a)
+  %partial.reduce = call reassoc <vscale x 8 x half> @llvm.vector.partial.reduce.fadd(<vscale x 8 x half> %acc, <vscale x 16 x half> %a)
   ret <vscale x 8 x half> %partial.reduce
 }
 
@@ -116,7 +116,7 @@ define <vscale x 4 x float> @partial_reduce_float_vl128(<vscale x 4 x float> %ac
 ; CHECK-NEXT:    fadd z0.s, z0.s, z2.s
 ; CHECK-NEXT:    ret
 entry:
-  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %a)
+  %partial.reduce = call reassoc <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %a)
   ret <vscale x 4 x float> %partial.reduce
 }
 
@@ -127,6 +127,6 @@ define <vscale x 2 x double> @partial_reduce_double_vl128(<vscale x 2 x double>
 ; CHECK-NEXT:    fadd z0.d, z0.d, z2.d
 ; CHECK-NEXT:    ret
 entry:
-  %partial.reduce = call <vscale x 2 x double> @llvm.vector.partial.reduce.fadd(<vscale x 2 x double> %acc, <vscale x 4 x double> %a)
+  %partial.reduce = call reassoc <vscale x 2 x double> @llvm.vector.partial.reduce.fadd(<vscale x 2 x double> %acc, <vscale x 4 x double> %a)
   ret <vscale x 2 x double> %partial.reduce
 }



More information about the llvm-commits mailing list