[llvm] Add `llvm.vector.partial.reduce.fadd` intrinsic (PR #159776)

Damian Heaton via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 19 06:18:39 PDT 2025


https://github.com/dheaton-arm created https://github.com/llvm/llvm-project/pull/159776

With this intrinsic, and supporting SelectionDAG nodes, we can better make use of instructions such as AArch64's `FDOT`.

>From c5b22aa74c5d4e849c7fe441350b4fd99d65efed Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Fri, 19 Sep 2025 10:51:26 +0000
Subject: [PATCH] Add `llvm.vector.partial.reduce.fadd` intrinsic

With this intrinsic, and supporting SelectionDAG nodes, we can better make use of instructions such as AArch64's `FDOT`.
---
 llvm/docs/LangRef.rst                         | 42 ++++++++++++
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  3 +-
 llvm/include/llvm/CodeGen/TargetLowering.h    |  4 +-
 llvm/include/llvm/IR/Intrinsics.td            |  4 ++
 .../include/llvm/Target/TargetSelectionDAG.td |  2 +
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 22 +++++--
 .../SelectionDAG/LegalizeVectorOps.cpp        |  2 +
 .../SelectionDAG/LegalizeVectorTypes.cpp      |  2 +
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  3 +-
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 13 ++++
 .../SelectionDAG/SelectionDAGDumper.cpp       |  2 +
 .../CodeGen/SelectionDAG/TargetLowering.cpp   | 21 +++---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 21 +++++-
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td |  3 +
 llvm/test/CodeGen/AArch64/sve2p1-fdot.ll      | 66 +++++++++++++++++++
 15 files changed, 190 insertions(+), 20 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/sve2p1-fdot.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 5fd0f6573bb97..069fcd29d808b 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20613,6 +20613,48 @@ performance, and an out-of-loop phase to calculate the final scalar result.
 By avoiding the introduction of new ordering constraints, these intrinsics
 enhance the ability to leverage a target's accumulation instructions.
 
+'``llvm.vector.partial.reduce.fadd.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+      declare <4 x float> @llvm.vector.partial.reduce.fadd.v4f32.v8f32(<4 x float> %a, <8 x float> %b)
+      declare <vscale x 4 x float> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x float> %a, <vscale x 8 x float> %b)
+
+Overview:
+"""""""""
+
+The '``llvm.vector.partial.reduce.fadd.*``' intrinsics use floating-point
+addition to reduce the concatenation of the two vector arguments down to the
+number of elements of the result vector type.
+
+Arguments:
+""""""""""
+
+The first argument is a floating-point vector with the same type as the result.
+
+The second argument is a floating-point vector with the same element type whose
+number of elements is a known integer multiple of the result's.
+
+Semantics:
+""""""""""
+
+Other than the reduction operator (``fadd``), the way in which the concatenated
+arguments are reduced is entirely unspecified. By their nature these intrinsics
+are not expected to be useful in isolation but instead implement the first phase
+of an overall reduction operation.
+
+The typical use case is loop vectorization where reductions are split into an
+in-loop phase, where maintaining an unordered vector result is important for
+performance, and an out-of-loop phase to calculate the final scalar result.
+
+By avoiding the introduction of new ordering constraints, these intrinsics
+enhance the ability to leverage a target's accumulation instructions.
+
 '``llvm.experimental.vector.histogram.*``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index c76c83d84b3c7..83ee6ff677e3d 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1510,6 +1510,7 @@ enum NodeType {
   PARTIAL_REDUCE_SMLA,  // sext, sext
   PARTIAL_REDUCE_UMLA,  // zext, zext
   PARTIAL_REDUCE_SUMLA, // sext, zext
+  PARTIAL_REDUCE_FMLA,  // fpext, fpext
 
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
@@ -1761,7 +1762,7 @@ LLVM_ABI CondCode getSetCCInverse(CondCode Operation, EVT Type);
 
 inline bool isExtOpcode(unsigned Opcode) {
   return Opcode == ISD::ANY_EXTEND || Opcode == ISD::ZERO_EXTEND ||
-         Opcode == ISD::SIGN_EXTEND;
+         Opcode == ISD::SIGN_EXTEND || Opcode == ISD::FP_EXTEND;
 }
 
 inline bool isExtVecInRegOpcode(unsigned Opcode) {
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 46be271320fdd..26fb5f087d6ba 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1672,7 +1672,7 @@ class LLVM_ABI TargetLoweringBase {
   LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
                                            EVT InputVT) const {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
                                     InputVT.getSimpleVT().SimpleTy};
     auto It = PartialReduceMLAActions.find(Key);
@@ -2774,7 +2774,7 @@ class LLVM_ABI TargetLoweringBase {
   void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
                                  LegalizeAction Action) {
     assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
-           Opc == ISD::PARTIAL_REDUCE_SUMLA);
+           Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
     assert(AccVT.isValid() && InputVT.isValid() &&
            "setPartialReduceMLAAction types aren't valid");
     PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 585371a6a4423..1ecfe284e05fa 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2801,6 +2801,10 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
                                                           [llvm_anyvector_ty, llvm_anyvector_ty],
                                                           [IntrNoMem]>;
 
+def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                                                        [llvm_anyfloat_ty, llvm_anyfloat_ty],
+                                                                        [IntrNoMem]>;
+
 //===----------------- Pointer Authentication Intrinsics ------------------===//
 //
 
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index ef88c9507c86d..1e36cc0a00505 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -523,6 +523,8 @@ def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
                                  SDTPartialReduceMLA>;
 def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
                                  SDTPartialReduceMLA>;
+def partial_reduce_fmla : SDNode<"ISD::PARTIAL_REDUCE_FMLA",
+                                 SDTPartialReduceMLA>;
 
 def fadd       : SDNode<"ISD::FADD"       , SDTFPBinOp, [SDNPCommutative]>;
 def fsub       : SDNode<"ISD::FSUB"       , SDTFPBinOp>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 4b20b756f8a15..7347d77172054 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2040,6 +2040,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
                                 return visitPARTIAL_REDUCE_MLA(N);
   case ISD::VECTOR_COMPRESS:    return visitVECTOR_COMPRESS(N);
   case ISD::LIFETIME_END:       return visitLIFETIME_END(N);
@@ -12942,8 +12943,15 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue Op2 = N->getOperand(2);
 
   APInt C;
-  if (Op1->getOpcode() != ISD::MUL ||
-      !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+  if (N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA) {
+    // Require a splat of exactly 1.0 in the element's own float semantics
+    // (f16, bf16, f32, f64, ...) rather than a fixed f32 bit pattern.
+    ConstantFPSDNode *FPOne = isConstOrConstSplatFP(Op2);
+    if (Op1->getOpcode() != ISD::FMUL || !FPOne || !FPOne->isExactlyValue(1.0))
+      return SDValue();
+  } else if (Op1->getOpcode() != ISD::MUL ||
+             !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne()) {
     return SDValue();
+  }
 
   SDValue LHS = Op1->getOperand(0);
@@ -12998,6 +13006,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
     NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
     std::swap(LHSExtOp, RHSExtOp);
+  } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
+    NewOpc = ISD::PARTIAL_REDUCE_FMLA;
   } else
     return SDValue();
   // For a 2-stage extend the signedness of both of the extends must match
@@ -13033,22 +13043,35 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
 
   APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
+  if (N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA) {
+    // Require a splat of exactly 1.0 in the element's own float semantics.
+    // Integer nodes take the branch below, so no integer splat can be
+    // mistaken for the bit pattern of a floating-point 1.0.
+    ConstantFPSDNode *FPOne = isConstOrConstSplatFP(Op2);
+    if (!FPOne || !FPOne->isExactlyValue(1.0))
+      return SDValue();
+  } else if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
+             !ConstantOne.isOne()) {
     return SDValue();
+  }
 
   unsigned Op1Opcode = Op1.getOpcode();
   if (!ISD::isExtOpcode(Op1Opcode))
     return SDValue();
 
-  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+  // ANY_EXTEND keeps its original (unsigned) treatment; FP_EXTEND counts as
+  // signed only so that it agrees with NodeIsSigned, which is true for FMLA.
+  bool Op1IsSigned =
+      Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND;
   bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
   if (Op1IsSigned != NodeIsSigned &&
       Op1.getValueType().getVectorElementType() != AccElemVT)
     return SDValue();
 
-  unsigned NewOpcode =
-      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+                           ? ISD::PARTIAL_REDUCE_FMLA
+                       : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
+                                     : ISD::PARTIAL_REDUCE_UMLA;
 
   SDValue UnextOp1 = Op1.getOperand(0);
   EVT UnextOp1VT = UnextOp1.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8e423c4f83b38..94751be5b7986 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Action =
         TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
                                       Node->getOperand(1).getValueType());
@@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
     return;
   case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index ff7cd665446cc..e6f19499b8f41 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1459,6 +1459,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
     break;
   case ISD::GET_ACTIVE_LANE_MASK:
@@ -3674,6 +3675,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
   case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA:
     Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
     break;
   }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 029eb025ff1de..bc082513786ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8329,7 +8329,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
   }
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
-  case ISD::PARTIAL_REDUCE_SUMLA: {
+  case ISD::PARTIAL_REDUCE_SUMLA:
+  case ISD::PARTIAL_REDUCE_FMLA: {
     [[maybe_unused]] EVT AccVT = N1.getValueType();
     [[maybe_unused]] EVT Input1VT = N2.getValueType();
     [[maybe_unused]] EVT Input2VT = N3.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 070d7978ce48f..448e3bbd02038 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8114,6 +8114,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
                          Input, DAG.getConstant(1, sdl, Input.getValueType())));
     return;
   }
+  case Intrinsic::vector_partial_reduce_fadd: {
+    if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+      visitTargetIntrinsic(I, Intrinsic);
+      return;
+    }
+    SDValue Acc = getValue(I.getOperand(0));
+    SDValue Input = getValue(I.getOperand(1));
+    setValue(&I,
+             DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
+                         Input,
+                         DAG.getConstantFP(1.0f, sdl, Input.getValueType())));
+    return;
+  }
   case Intrinsic::experimental_cttz_elts: {
     auto DL = getCurSDLoc();
     SDValue Op = getValue(I.getOperand(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 4b2a00c2e2cfa..cf5c269c20761 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -587,6 +587,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
     return "partial_reduce_smla";
   case ISD::PARTIAL_REDUCE_SUMLA:
     return "partial_reduce_sumla";
+  case ISD::PARTIAL_REDUCE_FMLA:
+    return "partial_reduce_fmla";
   case ISD::LOOP_DEPENDENCE_WAR_MASK:
     return "loop_dep_war";
   case ISD::LOOP_DEPENDENCE_RAW_MASK:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index fd6d20e146bb2..50aa41b77691d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12046,12 +12046,14 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
       EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
                        MulOpVT.getVectorElementCount());
 
-  unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
-                      ? ISD::ZERO_EXTEND
-                      : ISD::SIGN_EXTEND;
-  unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
-                      ? ISD::SIGN_EXTEND
-                      : ISD::ZERO_EXTEND;
+  unsigned ExtOpcLHS =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA   ? ISD::FP_EXTEND
+      : N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA ? ISD::ZERO_EXTEND
+                                                   : ISD::SIGN_EXTEND;
+  unsigned ExtOpcRHS =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA   ? ISD::FP_EXTEND
+      : N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ? ISD::SIGN_EXTEND
+                                                   : ISD::ZERO_EXTEND;
 
   if (ExtMulOpVT != MulOpVT) {
     MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
@@ -12060,7 +12062,14 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   SDValue Input = MulLHS;
   APInt ConstantOne;
-  if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
-      !ConstantOne.isOne())
-    Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+  if (N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA) {
+    // Skip the multiply only for a splat of exactly 1.0 in the element's own
+    // float semantics; otherwise multiply with FMUL (never an integer MUL).
+    ConstantFPSDNode *FPOne = isConstOrConstSplatFP(MulRHS);
+    if (!FPOne || !FPOne->isExactlyValue(1.0))
+      Input = DAG.getNode(ISD::FMUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+  } else if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
+             !ConstantOne.isOne()) {
+    Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+  }
 
   unsigned Stride = AccVT.getVectorMinNumElements();
@@ -12071,10 +12080,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
   for (unsigned I = 0; I < ScaleFactor; I++)
     Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride));
 
+  unsigned FlatNode =
+      N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FADD : ISD::ADD;
+
   // Flatten the subvector tree
   while (Subvectors.size() > 1) {
     Subvectors.push_back(
-        DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
+        DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]}));
     Subvectors.pop_front();
     Subvectors.pop_front();
   }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fc3efb072d57b..0ca596b634d11 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1919,6 +1919,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     }
   }
 
+  // Handle floating-point partial reduction (SME2 only in streaming mode)
+  if (Subtarget->hasSVE2p1() ||
+      (Subtarget->hasSME2() && Subtarget->isStreaming()))
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::nxv4f32,
+                              MVT::nxv8f16, Legal);
+
   // Handle non-aliasing elements mask
   if (Subtarget->hasSVE2() ||
       (Subtarget->hasSME() && Subtarget->isStreaming())) {
@@ -2184,7 +2190,8 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
 
 bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     const IntrinsicInst *I) const {
-  assert(I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add &&
+  assert((I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add ||
+          I->getIntrinsicID() == Intrinsic::vector_partial_reduce_fadd) &&
          "Unexpected intrinsic!");
   return true;
 }
@@ -22519,7 +22526,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
                                       SelectionDAG &DAG) {
 
   assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
+         (getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add ||
+          getIntrinsicID(N) == Intrinsic::vector_partial_reduce_fadd) &&
          "Expected a partial reduction node");
 
   bool Scalable = N->getValueType(0).isScalableVector();
@@ -22689,6 +22697,15 @@ static SDValue performIntrinsicCombine(SDNode *N,
                        N->getOperand(1), Input,
                        DAG.getConstant(1, DL, Input.getValueType()));
   }
+  case Intrinsic::vector_partial_reduce_fadd: {
+    if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+      return Dot;
+    SDLoc DL(N);
+    SDValue Input = N->getOperand(2);
+    return DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, DL, N->getValueType(0),
+                       N->getOperand(1), Input,
+                       DAG.getConstantFP(1.0f, DL, Input.getValueType()));
+  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 7fe4f7acdbd49..8ef69ad13abc5 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4228,6 +4228,9 @@ defm FCLAMP_ZZZ : sve_fp_clamp<"fclamp", AArch64fclamp>;
 defm FDOT_ZZZ_S  : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fdot_x2>;
 defm FDOT_ZZZI_S : sve_float_dot_indexed<0b0, 0b00, ZPR16, ZPR3b16, "fdot", nxv8f16, int_aarch64_sve_fdot_lane_x2>;
 
+def : Pat<(nxv4f32 (partial_reduce_fmla nxv4f32:$Acc, nxv8f16:$LHS, nxv8f16:$RHS)),
+          (FDOT_ZZZ_S $Acc, $LHS, $RHS)>;
+
 defm BFMLSLB_ZZZ_S : sve2_fp_mla_long<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb>;
 defm BFMLSLT_ZZZ_S : sve2_fp_mla_long<0b111, "bfmlslt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslt>;
 defm BFMLSLB_ZZZI_S : sve2_fp_mla_long_by_indexed_elem<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb_lane>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
new file mode 100644
index 0000000000000..5bb1fae43392f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
+
+define <vscale x 4 x float> @fdot_wide_vl128(<vscale x 4 x float> %acc, <vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fdot_wide_vl128:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fdot z0.s, z1.h, z2.h
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+  %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+  %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
+  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
+  ret <vscale x 4 x float> %partial.reduce
+}
+
+define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
+; CHECK-LABEL: fdot_wide_vl256:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x1]
+; CHECK-NEXT:    ld1h { z1.s }, p0/z, [x2]
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x1, #1, mul vl]
+; CHECK-NEXT:    ld1h { z3.s }, p0/z, [x2, #1, mul vl]
+; CHECK-NEXT:    fcvt z0.s, p0/m, z0.h
+; CHECK-NEXT:    fcvt z1.s, p0/m, z1.h
+; CHECK-NEXT:    fcvt z2.s, p0/m, z2.h
+; CHECK-NEXT:    fcvt z3.s, p0/m, z3.h
+; CHECK-NEXT:    fmul z0.s, z0.s, z1.s
+; CHECK-NEXT:    ldr z1, [x0]
+; CHECK-NEXT:    fmul z2.s, z2.s, z3.s
+; CHECK-NEXT:    fadd z0.s, z1.s, z0.s
+; CHECK-NEXT:    fadd z0.s, z0.s, z2.s
+; CHECK-NEXT:    str z0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %acc = load <8 x float>, ptr %accptr
+  %a = load <16 x half>, ptr %aptr
+  %b = load <16 x half>, ptr %bptr
+  %a.wide = fpext <16 x half> %a to <16 x float>
+  %b.wide = fpext <16 x half> %b to <16 x float>
+  %mult = fmul <16 x float> %a.wide, %b.wide
+  %partial.reduce = call <8 x float> @llvm.vector.partial.reduce.fadd(<8 x float> %acc, <16 x float> %mult)
+  store <8 x float> %partial.reduce, ptr %accptr
+  ret void
+}
+
+define <4 x float> @fixed_fdot_wide(<4 x float> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fixed_fdot_wide:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcvtl v3.4s, v1.4h
+; CHECK-NEXT:    fcvtl v4.4s, v2.4h
+; CHECK-NEXT:    fcvtl2 v1.4s, v1.8h
+; CHECK-NEXT:    fcvtl2 v2.4s, v2.8h
+; CHECK-NEXT:    fmul v3.4s, v3.4s, v4.4s
+; CHECK-NEXT:    fmul v1.4s, v1.4s, v2.4s
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v3.4s
+; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    ret
+entry:
+  %a.wide = fpext <8 x half> %a to <8 x float>
+  %b.wide = fpext <8 x half> %b to <8 x float>
+  %mult = fmul <8 x float> %a.wide, %b.wide
+  %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
+  ret <4 x float> %partial.reduce
+}



More information about the llvm-commits mailing list