[llvm] 70f4b59 - Add `llvm.vector.partial.reduce.fadd` intrinsic (#159776)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 7 07:36:58 PST 2025


Author: Damian Heaton
Date: 2025-11-07T15:36:54Z
New Revision: 70f4b596cf453369ce4111c23e7e93633e5fe4b1

URL: https://github.com/llvm/llvm-project/commit/70f4b596cf453369ce4111c23e7e93633e5fe4b1
DIFF: https://github.com/llvm/llvm-project/commit/70f4b596cf453369ce4111c23e7e93633e5fe4b1.diff

LOG: Add `llvm.vector.partial.reduce.fadd` intrinsic (#159776)

With this intrinsic, and supporting SelectionDAG nodes, we can better
make use of instructions such as AArch64's `FDOT`.
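
As a rough illustration of what the new intrinsic computes, the sketch
below models one permissible result for fixed-length vectors. It is not
code from this patch: the function name partialReduceFAddModel is
hypothetical. The LangRef text in the diff leaves the reduction order
unspecified, so this is only one valid outcome. It follows the shape of
the generic expansion in expandPartialReduceMLA, which adds
accumulator-sized subvectors of the input together with the accumulator.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model, for illustration only: one permissible result of
// llvm.vector.partial.reduce.fadd on fixed-length vectors. The intrinsic
// leaves pairing and ordering unspecified, so a real lowering may add the
// lanes in a different order (or fuse operations) and round differently.
std::vector<float> partialReduceFAddModel(const std::vector<float> &Acc,
                                          const std::vector<float> &Input) {
  const std::size_t Stride = Acc.size();
  assert(Stride != 0 && Input.size() % Stride == 0 &&
         "input length must be a multiple of the accumulator length");

  std::vector<float> Result = Acc;
  // Walk the input one accumulator-sized subvector at a time and add each
  // subvector lane-wise into the running result.
  for (std::size_t Base = 0; Base < Input.size(); Base += Stride)
    for (std::size_t Lane = 0; Lane < Stride; ++Lane)
      Result[Lane] += Input[Base + Lane];
  return Result;
}
```

With a 4-lane accumulator and an 8-lane input, lane i of the result in
this model is Acc[i] + Input[i] + Input[i + 4]. Because reassociation is
allowed, callers should only rely on the sum over all result lanes and
not on the value of any individual lane.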

Added: 
    llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
    llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll

Modified: 
    llvm/docs/LangRef.rst
    llvm/include/llvm/CodeGen/ISDOpcodes.h
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/include/llvm/IR/Intrinsics.td
    llvm/include/llvm/Target/TargetSelectionDAG.td
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/IR/Verifier.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Removed: 
    


################################################################################
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index bd0337f25c758..ab085ca0b1499 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20368,6 +20368,77 @@ Arguments:
 """"""""""
 The argument to this intrinsic must be a vector of floating-point values.
 
+Vector Partial Reduction Intrinsics
+-----------------------------------
+
+Partial reductions of vectors can be expressed using the intrinsics described in
+this section. Each one reduces the concatenation of the two vector arguments
+down to the number of elements of the result vector type.
+
+Other than the reduction operator (e.g. add, fadd), the way in which the
+concatenated arguments is reduced is entirely unspecified. By their nature these
+intrinsics are not expected to be useful in isolation but can instead be used to
+implement the first phase of an overall reduction operation.
+
+The typical use case is loop vectorization where reductions are split into an
+in-loop phase, where maintaining an unordered vector result is important for
+performance, and an out-of-loop phase is required to calculate the final scalar
+result.
+
+By avoiding the introduction of new ordering constraints, these intrinsics
+enhance the ability to leverage a target's accumulation instructions.
+
+'``llvm.vector.partial.reduce.add.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v8i32(<4 x i32> %a, <8 x i32> %b)
+      declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v16i32(<4 x i32> %a, <16 x i32> %b)
+      declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv8i32(<vscale x 4 x i32> %a, <vscale x 8 x i32> %b)
+      declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv16i32(<vscale x 4 x i32> %a, <vscale x 16 x i32> %b)
+
+Arguments:
+""""""""""
+
+The first argument is an integer vector with the same type as the result.
+
+The second argument is a vector with a length that is a known integer multiple
+of the result's type, while maintaining the same element type.
+
+'``llvm.vector.partial.reduce.fadd.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <4 x f32> @llvm.vector.partial.reduce.fadd.v4f32.v8f32(<4 x f32> %a, <8 x f32> %b)
+      declare <vscale x 4 x f32> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x f32> %a, <vscale x 8 x f32> %b)
+
+Arguments:
+""""""""""
+
+The first argument is a floating-point vector with the same type as the result.
+
+The second argument is a vector with a length that is a known integer multiple
+of the result's type, while maintaining the same element type.
+
+Semantics:
+""""""""""
+
+As the way in which the arguments to this floating-point intrinsic are reduced
+is unspecified, this intrinsic will assume floating-point reassociation and
+contraction can be leveraged to implement the reduction, which may result in
+variations to the results due to reordering or by lowering to different
+instructions (including combining multiple instructions into a single one).
+
 '``llvm.vector.insert``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -20741,50 +20812,6 @@ Note that it has the following implications:
 -  If ``%cnt`` is non-zero, the return value is non-zero as well.
 -  If ``%cnt`` is less than or equal to ``%max_lanes``, the return value is equal to ``%cnt``.
 
-'``llvm.vector.partial.reduce.add.*``' Intrinsic
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Syntax:
-"""""""
-This is an overloaded intrinsic.
-
-::
-
-      declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v8i32(<4 x i32> %a, <8 x i32> %b)
-      declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v16i32(<4 x i32> %a, <16 x i32> %b)
-      declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv8i32(<vscale x 4 x i32> %a, <vscale x 8 x i32> %b)
-      declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv16i32(<vscale x 4 x i32> %a, <vscale x 16 x i32> %b)
-
-Overview:
-"""""""""
-
-The '``llvm.vector.partial.reduce.add.*``' intrinsics reduce the
-concatenation of the two vector arguments down to the number of elements of the
-result vector type.
-
-Arguments:
-""""""""""
-
-The first argument is an integer vector with the same type as the result.
-
-The second argument is a vector with a length that is a known integer multiple
-of the result's type, while maintaining the same element type.
-
-Semantics:
-""""""""""
-
-Other than the reduction operator (e.g., add) the way in which the concatenated
-arguments is reduced is entirely unspecified. By their nature these intrinsics
-are not expected to be useful in isolation but instead implement the first phase
-of an overall reduction operation.
-
-The typical use case is loop vectorization where reductions are split into an
-in-loop phase, where maintaining an unordered vector result is important for
-performance, and an out-of-loop phase to calculate the final scalar result.
-
-By avoiding the introduction of new ordering constraints, these intrinsics
-enhance the ability to leverage a target's accumulation instructions.
-
 '``llvm.experimental.vector.histogram.*``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index af60c04513ea7..b3a2ced70e628 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1516,6 +1516,7 @@ enum NodeType {
   PARTIAL_REDUCE_SMLA,  // sext, sext
   PARTIAL_REDUCE_UMLA,  // zext, zext
   PARTIAL_REDUCE_SUMLA, // sext, zext
+  PARTIAL_REDUCE_FMLA,  // fpext, fpext
 
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 1759463ea7965..cd466dceb900f 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1949,6 +1949,10 @@ LLVM_ABI bool isNullOrNullSplat(SDValue V, bool AllowUndefs = false);
 /// be zero.
 LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
 
+/// Return true if the value is a constant floating-point value, or a splatted
+/// vector of a constant floating-point value, of 1.0 (with no undefs).
+LLVM_ABI bool isOneOrOneSplatFP(SDValue V, bool AllowUndefs = false);
+
 /// Return true if the value is a constant -1 integer or a splatted vector of a
 /// constant -1 integer (with no undefs).
 /// Does not permit build vector implicit truncation.

diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 2550c2bee5f71..98565f423df3e 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1679,7 +1679,7 @@ class LLVM_ABI TargetLoweringBase {
   LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
                                            EVT InputVT) const {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
                                     InputVT.getSimpleVT().SimpleTy};
     auto It = PartialReduceMLAActions.find(Key);
@@ -2793,7 +2793,7 @@ class LLVM_ABI TargetLoweringBase {
   void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
                                  LegalizeAction Action) {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     assert(AccVT.isValid() && InputVT.isValid() &&
            "setPartialReduceMLAAction types aren't valid");
     PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};

diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 520130c6dcb4c..07aa2faffa7c5 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2813,6 +2813,10 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
                                                           [IntrNoMem,
                                                            IntrSpeculatable]>;
 
+def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                                           [llvm_anyfloat_ty, llvm_anyfloat_ty],
+                                                           [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 

diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 07a858fd682fc..a9750a5ab03f9 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -527,6 +527,8 @@ def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
                                  SDTPartialReduceMLA>;
 def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
                                  SDTPartialReduceMLA>;
+def partial_reduce_fmla : SDNode<"ISD::PARTIAL_REDUCE_FMLA",
+                                 SDTPartialReduceMLA>;
 
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 10daca57558ff..f144f17d5a8f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2042,6 +2042,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
                                 return visitPARTIAL_REDUCE_MLA(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
@@ -13006,6 +13007,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
 //
 // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
 // -> partial_reduce_*mla(acc, x, C)
+//
+// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
+// -> partial_reduce_fmla(acc, a, b)
 SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDLoc DL(N);
   auto *Context = DAG.getContext();
@@ -13014,7 +13018,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue Op2 = N->getOperand(2);
 
   unsigned Opc = Op1->getOpcode();
-  if (Opc != ISD::MUL && Opc != ISD::SHL)
+  if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
     return SDValue();
 
   SDValue LHS = Op1->getOperand(0);
@@ -13033,13 +13037,16 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
     Opc = ISD::MUL;
   }
 
-  APInt C;
-  if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
-      !C.isOne())
+  if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) &&
+      !(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(Op2)))
     return SDValue();
 
+  auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
+    return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND);
+  };
+
   unsigned LHSOpcode = LHS->getOpcode();
-  if (!ISD::isExtOpcode(LHSOpcode))
+  if (!IsIntOrFPExtOpcode(LHSOpcode))
     return SDValue();
 
   SDValue LHSExtOp = LHS->getOperand(0);
@@ -13047,6 +13054,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
 
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
   // -> partial_reduce_*mla(acc, x, C)
+  APInt C;
   if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
     // TODO: Make use of partial_reduce_sumla here
     APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
@@ -13071,7 +13079,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   }
 
   unsigned RHSOpcode = RHS->getOpcode();
-  if (!ISD::isExtOpcode(RHSOpcode))
+  if (!IsIntOrFPExtOpcode(RHSOpcode))
     return SDValue();
 
   SDValue RHSExtOp = RHS->getOperand(0);
@@ -13088,6 +13096,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
     NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
     std::swap(LHSExtOp, RHSExtOp);
+  } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
+    NewOpc = ISD::PARTIAL_REDUCE_FMLA;
   } else
     return SDValue();
   // For a 2-stage extend the signedness of both of the extends must match
@@ -13115,30 +13125,33 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
 // partial.reduce.sumla(acc, sext(op), splat(1))
 // -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.fmla(acc, fpext(op), splat(1.0))
+// -> partial.reduce.fmla(acc, op, splat(1.0))
 SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   SDLoc DL(N);
   SDValue Acc = N->getOperand(0);
   SDValue Op1 = N->getOperand(1);
   SDValue Op2 = N->getOperand(2);
 
-  APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+  if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
     return SDValue();
 
   unsigned Op1Opcode = Op1.getOpcode();
-  if (!ISD::isExtOpcode(Op1Opcode))
+  if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
     return SDValue();
 
-  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+  bool Op1IsSigned =
+      Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND;
   bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
   if (Op1IsSigned != NodeIsSigned &&
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  unsigned NewOpcode =
-      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+                           ? ISD::PARTIAL_REDUCE_FMLA
+                       : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
+                                     : ISD::PARTIAL_REDUCE_UMLA;
 
   SDValue UnextOp1 = Op1.getOperand(0);
   EVT UnextOp1VT = UnextOp1.getValueType();
@@ -13148,8 +13161,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
           TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
     return SDValue();
 
+  SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+                         ? DAG.getConstantFP(1, DL, UnextOp1VT)
+                         : DAG.getConstant(1, DL, UnextOp1VT);
+
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
-                     DAG.getConstant(1, DL, UnextOp1VT));
+                     Constant);
 }
 
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8e423c4f83b38..94751be5b7986 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Action =
         TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
                                       Node->getOperand(1).getValueType());
@@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
     return;
   case ISD::VECREDUCE_SEQ_FADD:

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index bb4a8d9967f94..dd5c011bfe784 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1474,6 +1474,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
     break;
   case ISD::GET_ACTIVE_LANE_MASK:
@@ -3689,6 +3690,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
     break;
   }

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0a06752ae1846..bbc1d734cfef5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8404,7 +8404,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   }
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
-  case ISD::PARTIAL_REDUCE_SUMLA: {
+  case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA: {
     [[maybe_unused]] EVT AccVT = N1.getValueType();
     [[maybe_unused]] EVT Input1VT = N2.getValueType();
     [[maybe_unused]] EVT Input2VT = N3.getValueType();
@@ -13064,6 +13065,11 @@ bool llvm::isOneOrOneSplat(SDValue N, bool AllowUndefs) {
   return C && C->isOne();
 }
 
+bool llvm::isOneOrOneSplatFP(SDValue N, bool AllowUndefs) {
+  ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs);
+  return C && C->isExactlyValue(1.0);
+}
+
 bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
   N = peekThroughBitcasts(N);
   unsigned BitWidth = N.getScalarValueSizeInBits();

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 2f598b2f03e24..88b0809b767b5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8187,6 +8187,14 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
                          Input, DAG.getConstant(1, sdl, Input.getValueType())));
     return;
   }
+  case Intrinsic::vector_partial_reduce_fadd: {
+    SDValue Acc = getValue(I.getOperand(0));
+    SDValue Input = getValue(I.getOperand(1));
+    setValue(&I, DAG.getNode(
+                     ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
+                     Input, DAG.getConstantFP(1.0, sdl, Input.getValueType())));
+    return;
+  }
   case Intrinsic::experimental_cttz_elts: {
     auto DL = getCurSDLoc();
     SDValue Op = getValue(I.getOperand(0));

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index d3e162879304d..ec5edd5f13978 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -590,6 +590,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
     return "partial_reduce_smla";
   case ISD::PARTIAL_REDUCE_SUMLA:
     return "partial_reduce_sumla";
+  case ISD::PARTIAL_REDUCE_FMLA:
+    return "partial_reduce_fmla";
   case ISD::LOOP_DEPENDENCE_WAR_MASK:
     return "loop_dep_war";
   case ISD::LOOP_DEPENDENCE_RAW_MASK:

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 9bdf82210fed1..b51d6649af2ec 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12074,22 +12074,32 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
       EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
                        MulOpVT.getVectorElementCount());
 
-  unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
-                      ? ISD::ZERO_EXTEND
-                      : ISD::SIGN_EXTEND;
-  unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
-                      ? ISD::SIGN_EXTEND
-                      : ISD::ZERO_EXTEND;
+  unsigned ExtOpcLHS, ExtOpcRHS;
+  switch (N->getOpcode()) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case ISD::PARTIAL_REDUCE_UMLA:
+    ExtOpcLHS = ExtOpcRHS = ISD::ZERO_EXTEND;
+    break;
+  case ISD::PARTIAL_REDUCE_SMLA:
+    ExtOpcLHS = ExtOpcRHS = ISD::SIGN_EXTEND;
+    break;
+  case ISD::PARTIAL_REDUCE_FMLA:
+    ExtOpcLHS = ExtOpcRHS = ISD::FP_EXTEND;
+    break;
+  }
 
   if (ExtMulOpVT != MulOpVT) {
     MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
     MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
   }
   SDValue Input = MulLHS;
-  APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+  if (N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA) {
+    if (!llvm::isOneOrOneSplatFP(MulRHS))
+      Input = DAG.getNode(ISD::FMUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+  } else if (!llvm::isOneOrOneSplat(MulRHS)) {
     Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+  }
 
   unsigned Stride = AccVT.getVectorMinNumElements();
   unsigned ScaleFactor = MulOpVT.getVectorMinNumElements() / Stride;
@@ -12099,10 +12109,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   for (unsigned I = 0; I < ScaleFactor; I++)
     Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride));
 
+  unsigned FlatNode =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FADD : ISD::ADD;
+
   // Flatten the subvector tree
   while (Subvectors.size() > 1) {
     Subvectors.push_back(
-        DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
+        DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]}));
     Subvectors.pop_front();
     Subvectors.pop_front();
   }

diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index f1e473a78e3ac..59eb870798235 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6583,6 +6583,7 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
     }
     break;
   }
+  case Intrinsic::vector_partial_reduce_fadd:
   case Intrinsic::vector_partial_reduce_add: {
     VectorType *AccTy = cast<VectorType>(Call.getArgOperand(0)->getType());
     VectorType *VecTy = cast<VectorType>(Call.getArgOperand(1)->getType());

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 40e6400756c7c..c8a038fa99b30 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1916,6 +1916,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
       setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
     }
+
+    // Handle floating-point partial reduction
+    if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
+      setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::nxv4f32,
+                                MVT::nxv8f16, Legal);
+    }
   }
 
   // Handle non-aliasing elements mask
@@ -2283,6 +2289,11 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
                                 MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
   }
 
+  if (Subtarget->hasSVE2p1() && VT.getVectorElementType() == MVT::f32) {
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, VT,
+                              MVT::getVectorVT(MVT::f16, NumElts * 2), Custom);
+  }
+
   // Lower fixed length vector operations to scalable equivalents.
   setOperationAction(ISD::ABDS, VT, Default);
   setOperationAction(ISD::ABDU, VT, Default);
@@ -7875,6 +7886,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 3b268dcbca600..e1f43867bbe5b 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -375,6 +375,11 @@ def AArch64fclamp : PatFrags<(ops node:$Zd, node:$Zn, node:$Zm),
                                node:$Zm)
                                ]>;
 
+def AArch64fdot : PatFrags<(ops node:$Zd, node:$Zn, node:$Zm),
+                            [(int_aarch64_sve_fdot_x2 node:$Zd, node:$Zn, node:$Zm),
+                             (partial_reduce_fmla node:$Zd, node:$Zn, node:$Zm)
+                            ]>;
+
 def SDT_AArch64FCVT : SDTypeProfile<1, 3, [
   SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>,
   SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1>, SDTCisSameAs<0,3>
@@ -4251,7 +4256,7 @@ defm PSEL_PPPRI : sve2_int_perm_sel_p<"psel", int_aarch64_sve_psel>;
 let Predicates = [HasSVE2p1_or_SME2] in {
 defm FCLAMP_ZZZ : sve_fp_clamp<"fclamp", AArch64fclamp>;
 
-defm FDOT_ZZZ_S  : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fdot_x2>;
+defm FDOT_ZZZ_S  : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, AArch64fdot>;
 defm FDOT_ZZZI_S : sve_float_dot_indexed<0b0, 0b00, ZPR16, ZPR3b16, "fdot", nxv8f16, int_aarch64_sve_fdot_lane_x2>;
 
 defm BFMLSLB_ZZZ_S : sve2_fp_mla_long<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb>;

diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
new file mode 100644
index 0000000000000..9dbe096ebdb57
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mattr=+sve2 < %s | FileCheck %s --check-prefixes=CHECK,SVE2
+; RUN: llc -global-isel -global-isel-abort=2 -mattr=+sve2 < %s | FileCheck %s --check-prefixes=CHECK,SVE2
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,SVE2P1
+; RUN: llc -global-isel -global-isel-abort=2 -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,SVE2P1
+
+target triple = "aarch64-linux-gnu"
+
+define <vscale x 4 x float> @fdot_wide_nxv4f32(<vscale x 4 x float> %acc, <vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; SVE2-LABEL: fdot_wide_nxv4f32:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    uunpklo z3.s, z1.h
+; SVE2-NEXT:    uunpklo z4.s, z2.h
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    uunpkhi z1.s, z1.h
+; SVE2-NEXT:    uunpkhi z2.s, z2.h
+; SVE2-NEXT:    fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT:    fcvt z4.s, p0/m, z4.h
+; SVE2-NEXT:    fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT:    fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT:    fmul z3.s, z3.s, z4.s
+; SVE2-NEXT:    fmul z1.s, z1.s, z2.s
+; SVE2-NEXT:    fadd z0.s, z0.s, z3.s
+; SVE2-NEXT:    fadd z0.s, z0.s, z1.s
+; SVE2-NEXT:    ret
+;
+; SVE2P1-LABEL: fdot_wide_nxv4f32:
+; SVE2P1:       // %bb.0: // %entry
+; SVE2P1-NEXT:    fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT:    ret
+entry:
+  %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+  %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+  %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
+  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
+  ret <vscale x 4 x float> %partial.reduce
+}
+
+define <vscale x 4 x float> @fdot_splat_nxv4f32(<vscale x 4 x float> %acc, <vscale x 8 x half> %a) {
+; SVE2-LABEL: fdot_splat_nxv4f32:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    uunpklo z2.s, z1.h
+; SVE2-NEXT:    ptrue p0.s
+; SVE2-NEXT:    uunpkhi z1.s, z1.h
+; SVE2-NEXT:    fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT:    fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT:    fadd z0.s, z0.s, z2.s
+; SVE2-NEXT:    fadd z0.s, z0.s, z1.s
+; SVE2-NEXT:    ret
+;
+; SVE2P1-LABEL: fdot_splat_nxv4f32:
+; SVE2P1:       // %bb.0: // %entry
+; SVE2P1-NEXT:    fmov z2.h, #1.00000000
+; SVE2P1-NEXT:    fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT:    ret
+entry:
+  %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %a.wide)
+  ret <vscale x 4 x float> %partial.reduce
+}
+
+define <vscale x 8 x half> @partial_reduce_nxv8f16(<vscale x 8 x half> %acc, <vscale x 16 x half> %a) {
+; CHECK-LABEL: partial_reduce_nxv8f16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd z0.h, z0.h, z1.h
+; CHECK-NEXT:    fadd z0.h, z0.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <vscale x 8 x half> @llvm.vector.partial.reduce.fadd(<vscale x 8 x half> %acc, <vscale x 16 x half> %a)
+  ret <vscale x 8 x half> %partial.reduce
+}
+
+define <vscale x 4 x float> @partial_reduce_nxv4f32(<vscale x 4 x float> %acc, <vscale x 8 x float> %a) {
+; CHECK-LABEL: partial_reduce_nxv4f32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    fadd z0.s, z0.s, z2.s
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %a)
+  ret <vscale x 4 x float> %partial.reduce
+}
+
+define <vscale x 2 x double> @partial_reduce_nxv2f64(<vscale x 2 x double> %acc, <vscale x 4 x double> %a) {
+; CHECK-LABEL: partial_reduce_nxv2f64:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd z0.d, z0.d, z1.d
+; CHECK-NEXT:    fadd z0.d, z0.d, z2.d
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <vscale x 2 x double> @llvm.vector.partial.reduce.fadd(<vscale x 2 x double> %acc, <vscale x 4 x double> %a)
+  ret <vscale x 2 x double> %partial.reduce
+}

diff  --git a/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
new file mode 100644
index 0000000000000..89216ce2cb72b
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
@@ -0,0 +1,230 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mattr=+sve2 < %s | FileCheck %s --check-prefixes=CHECK,SVE2
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,SVE2P1
+
+target triple = "aarch64-linux-gnu"
+
+define void @fdot_wide_v8f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,0) {
+; SVE2-LABEL: fdot_wide_v8f32:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    ptrue p0.s, vl8
+; SVE2-NEXT:    mov x8, #8 // =0x8
+; SVE2-NEXT:    ld1h { z0.s }, p0/z, [x1]
+; SVE2-NEXT:    ld1h { z1.s }, p0/z, [x2]
+; SVE2-NEXT:    ld1h { z2.s }, p0/z, [x1, x8, lsl #1]
+; SVE2-NEXT:    ld1h { z3.s }, p0/z, [x2, x8, lsl #1]
+; SVE2-NEXT:    fcvt z0.s, p0/m, z0.h
+; SVE2-NEXT:    fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT:    fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT:    fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT:    ld1w { z1.s }, p0/z, [x0]
+; SVE2-NEXT:    fmul z2.s, p0/m, z2.s, z3.s
+; SVE2-NEXT:    fadd z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT:    fadd z0.s, p0/m, z0.s, z2.s
+; SVE2-NEXT:    st1w { z0.s }, p0, [x0]
+; SVE2-NEXT:    ret
+;
+; SVE2P1-LABEL: fdot_wide_v8f32:
+; SVE2P1:       // %bb.0: // %entry
+; SVE2P1-NEXT:    ptrue p0.s, vl8
+; SVE2P1-NEXT:    ptrue p1.h, vl16
+; SVE2P1-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; SVE2P1-NEXT:    ld1h { z1.h }, p1/z, [x1]
+; SVE2P1-NEXT:    ld1h { z2.h }, p1/z, [x2]
+; SVE2P1-NEXT:    fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT:    st1w { z0.s }, p0, [x0]
+; SVE2P1-NEXT:    ret
+entry:
+  %acc = load <8 x float>, ptr %accptr
+  %a = load <16 x half>, ptr %aptr
+  %b = load <16 x half>, ptr %bptr
+  %a.wide = fpext <16 x half> %a to <16 x float>
+  %b.wide = fpext <16 x half> %b to <16 x float>
+  %mult = fmul <16 x float> %a.wide, %b.wide
+  %partial.reduce = call <8 x float> @llvm.vector.partial.reduce.fadd(<8 x float> %acc, <16 x float> %mult)
+  store <8 x float> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define void @fdot_wide_v16f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(4,0) {
+; SVE2-LABEL: fdot_wide_v16f32:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    ptrue p0.s, vl16
+; SVE2-NEXT:    mov x8, #16 // =0x10
+; SVE2-NEXT:    ld1h { z0.s }, p0/z, [x1]
+; SVE2-NEXT:    ld1h { z1.s }, p0/z, [x2]
+; SVE2-NEXT:    ld1h { z2.s }, p0/z, [x1, x8, lsl #1]
+; SVE2-NEXT:    ld1h { z3.s }, p0/z, [x2, x8, lsl #1]
+; SVE2-NEXT:    fcvt z0.s, p0/m, z0.h
+; SVE2-NEXT:    fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT:    fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT:    fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT:    ld1w { z1.s }, p0/z, [x0]
+; SVE2-NEXT:    fmul z2.s, p0/m, z2.s, z3.s
+; SVE2-NEXT:    fadd z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT:    fadd z0.s, p0/m, z0.s, z2.s
+; SVE2-NEXT:    st1w { z0.s }, p0, [x0]
+; SVE2-NEXT:    ret
+;
+; SVE2P1-LABEL: fdot_wide_v16f32:
+; SVE2P1:       // %bb.0: // %entry
+; SVE2P1-NEXT:    ptrue p0.s, vl16
+; SVE2P1-NEXT:    ptrue p1.h, vl32
+; SVE2P1-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; SVE2P1-NEXT:    ld1h { z1.h }, p1/z, [x1]
+; SVE2P1-NEXT:    ld1h { z2.h }, p1/z, [x2]
+; SVE2P1-NEXT:    fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT:    st1w { z0.s }, p0, [x0]
+; SVE2P1-NEXT:    ret
+entry:
+  %acc = load <16 x float>, ptr %accptr
+  %a = load <32 x half>, ptr %aptr
+  %b = load <32 x half>, ptr %bptr
+  %a.wide = fpext <32 x half> %a to <32 x float>
+  %b.wide = fpext <32 x half> %b to <32 x float>
+  %mult = fmul <32 x float> %a.wide, %b.wide
+  %partial.reduce = call <16 x float> @llvm.vector.partial.reduce.fadd(<16 x float> %acc, <32 x float> %mult)
+  store <16 x float> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define void @fdot_wide_v32f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(8,0) {
+; SVE2-LABEL: fdot_wide_v32f32:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    ptrue p0.s, vl32
+; SVE2-NEXT:    mov x8, #32 // =0x20
+; SVE2-NEXT:    ld1h { z0.s }, p0/z, [x1]
+; SVE2-NEXT:    ld1h { z1.s }, p0/z, [x2]
+; SVE2-NEXT:    ld1h { z2.s }, p0/z, [x1, x8, lsl #1]
+; SVE2-NEXT:    ld1h { z3.s }, p0/z, [x2, x8, lsl #1]
+; SVE2-NEXT:    fcvt z0.s, p0/m, z0.h
+; SVE2-NEXT:    fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT:    fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT:    fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT:    ld1w { z1.s }, p0/z, [x0]
+; SVE2-NEXT:    fmul z2.s, p0/m, z2.s, z3.s
+; SVE2-NEXT:    fadd z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT:    fadd z0.s, p0/m, z0.s, z2.s
+; SVE2-NEXT:    st1w { z0.s }, p0, [x0]
+; SVE2-NEXT:    ret
+;
+; SVE2P1-LABEL: fdot_wide_v32f32:
+; SVE2P1:       // %bb.0: // %entry
+; SVE2P1-NEXT:    ptrue p0.s, vl32
+; SVE2P1-NEXT:    ptrue p1.h, vl64
+; SVE2P1-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; SVE2P1-NEXT:    ld1h { z1.h }, p1/z, [x1]
+; SVE2P1-NEXT:    ld1h { z2.h }, p1/z, [x2]
+; SVE2P1-NEXT:    fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT:    st1w { z0.s }, p0, [x0]
+; SVE2P1-NEXT:    ret
+entry:
+  %acc = load <32 x float>, ptr %accptr
+  %a = load <64 x half>, ptr %aptr
+  %b = load <64 x half>, ptr %bptr
+  %a.wide = fpext <64 x half> %a to <64 x float>
+  %b.wide = fpext <64 x half> %b to <64 x float>
+  %mult = fmul <64 x float> %a.wide, %b.wide
+  %partial.reduce = call <32 x float> @llvm.vector.partial.reduce.fadd(<32 x float> %acc, <64 x float> %mult)
+  store <32 x float> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define void @fdot_wide_v64f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(16,0) {
+; SVE2-LABEL: fdot_wide_v64f32:
+; SVE2:       // %bb.0: // %entry
+; SVE2-NEXT:    ptrue p0.s, vl64
+; SVE2-NEXT:    mov x8, #64 // =0x40
+; SVE2-NEXT:    ld1h { z0.s }, p0/z, [x1]
+; SVE2-NEXT:    ld1h { z1.s }, p0/z, [x2]
+; SVE2-NEXT:    ld1h { z2.s }, p0/z, [x1, x8, lsl #1]
+; SVE2-NEXT:    ld1h { z3.s }, p0/z, [x2, x8, lsl #1]
+; SVE2-NEXT:    fcvt z0.s, p0/m, z0.h
+; SVE2-NEXT:    fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT:    fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT:    fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT:    fmul z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT:    ld1w { z1.s }, p0/z, [x0]
+; SVE2-NEXT:    fmul z2.s, p0/m, z2.s, z3.s
+; SVE2-NEXT:    fadd z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT:    fadd z0.s, p0/m, z0.s, z2.s
+; SVE2-NEXT:    st1w { z0.s }, p0, [x0]
+; SVE2-NEXT:    ret
+;
+; SVE2P1-LABEL: fdot_wide_v64f32:
+; SVE2P1:       // %bb.0: // %entry
+; SVE2P1-NEXT:    ptrue p0.s, vl64
+; SVE2P1-NEXT:    ptrue p1.h, vl128
+; SVE2P1-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; SVE2P1-NEXT:    ld1h { z1.h }, p1/z, [x1]
+; SVE2P1-NEXT:    ld1h { z2.h }, p1/z, [x2]
+; SVE2P1-NEXT:    fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT:    st1w { z0.s }, p0, [x0]
+; SVE2P1-NEXT:    ret
+entry:
+  %acc = load <64 x float>, ptr %accptr
+  %a = load <128 x half>, ptr %aptr
+  %b = load <128 x half>, ptr %bptr
+  %a.wide = fpext <128 x half> %a to <128 x float>
+  %b.wide = fpext <128 x half> %b to <128 x float>
+  %mult = fmul <128 x float> %a.wide, %b.wide
+  %partial.reduce = call <64 x float> @llvm.vector.partial.reduce.fadd(<64 x float> %acc, <128 x float> %mult)
+  store <64 x float> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define <4 x float> @fixed_fdot_wide(<4 x float> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fixed_fdot_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcvtl v3.4s, v1.4h
+; CHECK-NEXT:    fcvtl v4.4s, v2.4h
+; CHECK-NEXT:    fcvtl2 v1.4s, v1.8h
+; CHECK-NEXT:    fcvtl2 v2.4s, v2.8h
+; CHECK-NEXT:    fmul v3.4s, v3.4s, v4.4s
+; CHECK-NEXT:    fmul v1.4s, v1.4s, v2.4s
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v3.4s
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = fpext <8 x half> %a to <8 x float>
+  %b.wide = fpext <8 x half> %b to <8 x float>
+  %mult = fmul <8 x float> %a.wide, %b.wide
+  %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
+  ret <4 x float> %partial.reduce
+}
+
+define <8 x half> @partial_reduce_half(<8 x half> %acc, <16 x half> %a) {
+; CHECK-LABEL: partial_reduce_half:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    fadd v0.8h, v0.8h, v2.8h
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <8 x half> @llvm.vector.partial.reduce.fadd(<8 x half> %acc, <16 x half> %a)
+  ret <8 x half> %partial.reduce
+}
+
+define <4 x float> @partial_reduce_float(<4 x float> %acc, <8 x float> %a) {
+; CHECK-LABEL: partial_reduce_float:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v2.4s
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %a)
+  ret <4 x float> %partial.reduce
+}
+
+define <2 x double> @partial_reduce_double(<2 x double> %acc, <4 x double> %a) {
+; CHECK-LABEL: partial_reduce_double:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fadd v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    fadd v0.2d, v0.2d, v2.2d
+; CHECK-NEXT:    ret
+entry:
+  %partial.reduce = call <2 x double> @llvm.vector.partial.reduce.fadd(<2 x double> %acc, <4 x double> %a)
+  ret <2 x double> %partial.reduce
+}


        

