[llvm] Add `llvm.vector.partial.reduction.fadd` intrinsic (PR #159776)
Damian Heaton via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 19 06:18:39 PDT 2025
https://github.com/dheaton-arm created https://github.com/llvm/llvm-project/pull/159776
With this intrinsic and supporting SelectionDAG nodes, we can make better use of instructions such as AArch64's `FDOT`.
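For example (mirroring the tests added below), a widening half-to-float dot product accumulated into an `nxv4f32` accumulator can be expressed as:

  %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
  %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
  %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
  %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)

which, with SVE2p1 available, can now lower to a single `fdot z0.s, z1.h, z2.h`.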
From c5b22aa74c5d4e849c7fe441350b4fd99d65efed Mon Sep 17 00:00:00 2001
From: Damian Heaton <Damian.Heaton at arm.com>
Date: Fri, 19 Sep 2025 10:51:26 +0000
Subject: [PATCH] Add `llvm.vector.partial.reduction.fadd` intrinsic
With this intrinsic and supporting SelectionDAG nodes, we can make better use of instructions such as AArch64's `FDOT`.
---
llvm/docs/LangRef.rst | 42 ++++++++++++
llvm/include/llvm/CodeGen/ISDOpcodes.h | 3 +-
llvm/include/llvm/CodeGen/TargetLowering.h | 4 +-
llvm/include/llvm/IR/Intrinsics.td | 4 ++
.../include/llvm/Target/TargetSelectionDAG.td | 2 +
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 22 +++++--
.../SelectionDAG/LegalizeVectorOps.cpp | 2 +
.../SelectionDAG/LegalizeVectorTypes.cpp | 2 +
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 3 +-
.../SelectionDAG/SelectionDAGBuilder.cpp | 13 ++++
.../SelectionDAG/SelectionDAGDumper.cpp | 2 +
.../CodeGen/SelectionDAG/TargetLowering.cpp | 21 +++---
.../Target/AArch64/AArch64ISelLowering.cpp | 21 +++++-
.../lib/Target/AArch64/AArch64SVEInstrInfo.td | 3 +
llvm/test/CodeGen/AArch64/sve2p1-fdot.ll | 66 +++++++++++++++++++
15 files changed, 190 insertions(+), 20 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 5fd0f6573bb97..069fcd29d808b 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20613,6 +20613,48 @@ performance, and an out-of-loop phase to calculate the final scalar result.
By avoiding the introduction of new ordering constraints, these intrinsics
enhance the ability to leverage a target's accumulation instructions.
+'``llvm.vector.partial.reduce.fadd.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+ declare <4 x float> @llvm.vector.partial.reduce.fadd.v4f32.v4f32.v8f32(<4 x float> %a, <8 x float> %b)
+ declare <vscale x 4 x float> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv4f32.nxv8f32(<vscale x 4 x float> %a, <vscale x 8 x float> %b)
+
+Overview:
+"""""""""
+
+The '``llvm.vector.partial.reduce.fadd.*``' intrinsics reduce the
+concatenation of the two vector arguments down to the number of elements of the
+result vector type.
+
+Arguments:
+""""""""""
+
+The first argument is a floating-point vector with the same type as the result.
+
+The second argument is a floating-point vector with a length that is a known
+integer multiple of the result vector's length, and with the same element type.
+
+Semantics:
+""""""""""
+
+Other than the reduction operator (e.g. fadd), the way in which the concatenated
+arguments are reduced is entirely unspecified. By their nature these intrinsics
+are not expected to be useful in isolation but instead implement the first phase
+of an overall reduction operation.
+
+The typical use case is loop vectorization where reductions are split into an
+in-loop phase, where maintaining an unordered vector result is important for
+performance, and an out-of-loop phase to calculate the final scalar result.
+
+By avoiding the introduction of new ordering constraints, these intrinsics
+enhance the ability to leverage a target's accumulation instructions.
+
'``llvm.experimental.vector.histogram.*``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index c76c83d84b3c7..83ee6ff677e3d 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1510,6 +1510,7 @@ enum NodeType {
PARTIAL_REDUCE_SMLA, // sext, sext
PARTIAL_REDUCE_UMLA, // zext, zext
PARTIAL_REDUCE_SUMLA, // sext, zext
+ PARTIAL_REDUCE_FMLA, // fpext, fpext
// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
@@ -1761,7 +1762,7 @@ LLVM_ABI CondCode getSetCCInverse(CondCode Operation, EVT Type);
inline bool isExtOpcode(unsigned Opcode) {
return Opcode == ISD::ANY_EXTEND || Opcode == ISD::ZERO_EXTEND ||
- Opcode == ISD::SIGN_EXTEND;
+ Opcode == ISD::SIGN_EXTEND || Opcode == ISD::FP_EXTEND;
}
inline bool isExtVecInRegOpcode(unsigned Opcode) {
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 46be271320fdd..26fb5f087d6ba 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1672,7 +1672,7 @@ class LLVM_ABI TargetLoweringBase {
LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
EVT InputVT) const {
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
- Opc == ISD::PARTIAL_REDUCE_SUMLA);
+ Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
InputVT.getSimpleVT().SimpleTy};
auto It = PartialReduceMLAActions.find(Key);
@@ -2774,7 +2774,7 @@ class LLVM_ABI TargetLoweringBase {
void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
LegalizeAction Action) {
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
- Opc == ISD::PARTIAL_REDUCE_SUMLA);
+ Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
assert(AccVT.isValid() && InputVT.isValid() &&
"setPartialReduceMLAAction types aren't valid");
PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 585371a6a4423..1ecfe284e05fa 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2801,6 +2801,10 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
[llvm_anyvector_ty, llvm_anyvector_ty],
[IntrNoMem]>;
+def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+ [llvm_anyfloat_ty, llvm_anyfloat_ty],
+ [IntrNoMem]>;
+
//===----------------- Pointer Authentication Intrinsics ------------------===//
//
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index ef88c9507c86d..1e36cc0a00505 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -523,6 +523,8 @@ def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
SDTPartialReduceMLA>;
def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
SDTPartialReduceMLA>;
+def partial_reduce_fmla : SDNode<"ISD::PARTIAL_REDUCE_FMLA",
+ SDTPartialReduceMLA>;
def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>;
def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 4b20b756f8a15..7347d77172054 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2040,6 +2040,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -12942,8 +12943,11 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op2 = N->getOperand(2);
APInt C;
- if (Op1->getOpcode() != ISD::MUL ||
- !ISD::isConstantSplatVector(Op2.getNode(), C) || !C.isOne())
+ if (!(Op1->getOpcode() == ISD::MUL &&
+ ISD::isConstantSplatVector(Op2.getNode(), C) && C.isOne()) &&
+ !(Op1->getOpcode() == ISD::FMUL &&
+ ISD::isConstantSplatVector(Op2.getNode(), C) &&
+ C == APFloat(1.0f).bitcastToAPInt().trunc(C.getBitWidth())))
return SDValue();
SDValue LHS = Op1->getOperand(0);
@@ -12998,6 +13002,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
std::swap(LHSExtOp, RHSExtOp);
+ } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
+ NewOpc = ISD::PARTIAL_REDUCE_FMLA;
} else
return SDValue();
// For a 2-stage extend the signedness of both of the extends must match
@@ -13033,22 +13039,26 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
APInt ConstantOne;
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ !(ConstantOne.isOne() ||
+ ConstantOne ==
+ APFloat(1.0f).bitcastToAPInt().trunc(ConstantOne.getBitWidth())))
return SDValue();
unsigned Op1Opcode = Op1.getOpcode();
if (!ISD::isExtOpcode(Op1Opcode))
return SDValue();
- bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+ bool Op1IsSigned = Op1Opcode != ISD::ZERO_EXTEND;
bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();
- unsigned NewOpcode =
- Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? ISD::PARTIAL_REDUCE_FMLA
+ : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
+ : ISD::PARTIAL_REDUCE_UMLA;
SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8e423c4f83b38..94751be5b7986 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
Action =
TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
Node->getOperand(1).getValueType());
@@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index ff7cd665446cc..e6f19499b8f41 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1459,6 +1459,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
break;
case ISD::GET_ACTIVE_LANE_MASK:
@@ -3674,6 +3675,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
break;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 029eb025ff1de..bc082513786ef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8329,7 +8329,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
- case ISD::PARTIAL_REDUCE_SUMLA: {
+ case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA: {
[[maybe_unused]] EVT AccVT = N1.getValueType();
[[maybe_unused]] EVT Input1VT = N2.getValueType();
[[maybe_unused]] EVT Input2VT = N3.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 070d7978ce48f..448e3bbd02038 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8114,6 +8114,19 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
Input, DAG.getConstant(1, sdl, Input.getValueType())));
return;
}
+ case Intrinsic::vector_partial_reduce_fadd: {
+ if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
+ visitTargetIntrinsic(I, Intrinsic);
+ return;
+ }
+ SDValue Acc = getValue(I.getOperand(0));
+ SDValue Input = getValue(I.getOperand(1));
+ setValue(&I,
+ DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
+ Input,
+ DAG.getConstantFP(1.0f, sdl, Input.getValueType())));
+ return;
+ }
case Intrinsic::experimental_cttz_elts: {
auto DL = getCurSDLoc();
SDValue Op = getValue(I.getOperand(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 4b2a00c2e2cfa..cf5c269c20761 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -587,6 +587,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
return "partial_reduce_smla";
case ISD::PARTIAL_REDUCE_SUMLA:
return "partial_reduce_sumla";
+ case ISD::PARTIAL_REDUCE_FMLA:
+ return "partial_reduce_fmla";
case ISD::LOOP_DEPENDENCE_WAR_MASK:
return "loop_dep_war";
case ISD::LOOP_DEPENDENCE_RAW_MASK:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index fd6d20e146bb2..50aa41b77691d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12046,12 +12046,14 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
MulOpVT.getVectorElementCount());
- unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
- ? ISD::ZERO_EXTEND
- : ISD::SIGN_EXTEND;
- unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
- ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND;
+ unsigned ExtOpcLHS =
+ N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FP_EXTEND
+ : N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA ? ISD::ZERO_EXTEND
+ : ISD::SIGN_EXTEND;
+ unsigned ExtOpcRHS =
+ N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FP_EXTEND
+ : N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ? ISD::SIGN_EXTEND
+ : ISD::ZERO_EXTEND;
if (ExtMulOpVT != MulOpVT) {
MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
@@ -12060,7 +12062,7 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
SDValue Input = MulLHS;
APInt ConstantOne;
if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ !(ConstantOne.isOne() || ConstantOne == APFloat(1.0f).bitcastToAPInt()))
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
unsigned Stride = AccVT.getVectorMinNumElements();
@@ -12071,10 +12073,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
for (unsigned I = 0; I < ScaleFactor; I++)
Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride));
+ unsigned FlatNode =
+ N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FADD : ISD::ADD;
+
// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(
- DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
+ DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fc3efb072d57b..0ca596b634d11 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1919,6 +1919,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
}
}
+ // Handle floating-point partial reduction
+ if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
+ static const unsigned FMLAOps[] = {ISD::PARTIAL_REDUCE_FMLA};
+ setPartialReduceMLAAction(FMLAOps, MVT::nxv4f32, MVT::nxv8f16, Legal);
+ }
+
// Handle non-aliasing elements mask
if (Subtarget->hasSVE2() ||
(Subtarget->hasSME() && Subtarget->isStreaming())) {
@@ -2184,7 +2190,8 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
const IntrinsicInst *I) const {
- assert(I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add &&
+ assert((I->getIntrinsicID() == Intrinsic::vector_partial_reduce_add ||
+ I->getIntrinsicID() == Intrinsic::vector_partial_reduce_fadd) &&
"Unexpected intrinsic!");
return true;
}
@@ -22519,7 +22526,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
- getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add &&
+ (getIntrinsicID(N) == Intrinsic::vector_partial_reduce_add ||
+ getIntrinsicID(N) == Intrinsic::vector_partial_reduce_fadd) &&
"Expected a partial reduction node");
bool Scalable = N->getValueType(0).isScalableVector();
@@ -22689,6 +22697,15 @@ static SDValue performIntrinsicCombine(SDNode *N,
N->getOperand(1), Input,
DAG.getConstant(1, DL, Input.getValueType()));
}
+ case Intrinsic::vector_partial_reduce_fadd: {
+ if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+ return Dot;
+ SDLoc DL(N);
+ SDValue Input = N->getOperand(2);
+ return DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, DL, N->getValueType(0),
+ N->getOperand(1), Input,
+ DAG.getConstantFP(1.0f, DL, Input.getValueType()));
+ }
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 7fe4f7acdbd49..8ef69ad13abc5 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -4228,6 +4228,9 @@ defm FCLAMP_ZZZ : sve_fp_clamp<"fclamp", AArch64fclamp>;
defm FDOT_ZZZ_S : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fdot_x2>;
defm FDOT_ZZZI_S : sve_float_dot_indexed<0b0, 0b00, ZPR16, ZPR3b16, "fdot", nxv8f16, int_aarch64_sve_fdot_lane_x2>;
+def : Pat<(nxv4f32 (partial_reduce_fmla nxv4f32:$Acc, nxv8f16:$LHS, nxv8f16:$RHS)),
+ (FDOT_ZZZ_S $Acc, $LHS, $RHS)>;
+
defm BFMLSLB_ZZZ_S : sve2_fp_mla_long<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb>;
defm BFMLSLT_ZZZ_S : sve2_fp_mla_long<0b111, "bfmlslt", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslt>;
defm BFMLSLB_ZZZI_S : sve2_fp_mla_long_by_indexed_elem<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb_lane>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
new file mode 100644
index 0000000000000..5bb1fae43392f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
+
+define <vscale x 4 x float> @fdot_wide_vl128(<vscale x 4 x float> %acc, <vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; CHECK-LABEL: fdot_wide_vl128:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fdot z0.s, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
+ %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
+ ret <vscale x 4 x float> %partial.reduce
+}
+
+define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
+; CHECK-LABEL: fdot_wide_vl256:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ld1h { z0.s }, p0/z, [x1]
+; CHECK-NEXT: ld1h { z1.s }, p0/z, [x2]
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x1, #1, mul vl]
+; CHECK-NEXT: ld1h { z3.s }, p0/z, [x2, #1, mul vl]
+; CHECK-NEXT: fcvt z0.s, p0/m, z0.h
+; CHECK-NEXT: fcvt z1.s, p0/m, z1.h
+; CHECK-NEXT: fcvt z2.s, p0/m, z2.h
+; CHECK-NEXT: fcvt z3.s, p0/m, z3.h
+; CHECK-NEXT: fmul z0.s, z0.s, z1.s
+; CHECK-NEXT: ldr z1, [x0]
+; CHECK-NEXT: fmul z2.s, z2.s, z3.s
+; CHECK-NEXT: fadd z0.s, z1.s, z0.s
+; CHECK-NEXT: fadd z0.s, z0.s, z2.s
+; CHECK-NEXT: str z0, [x0]
+; CHECK-NEXT: ret
+entry:
+ %acc = load <8 x float>, ptr %accptr
+ %a = load <16 x half>, ptr %aptr
+ %b = load <16 x half>, ptr %bptr
+ %a.wide = fpext <16 x half> %a to <16 x float>
+ %b.wide = fpext <16 x half> %b to <16 x float>
+ %mult = fmul <16 x float> %a.wide, %b.wide
+ %partial.reduce = call <8 x float> @llvm.vector.partial.reduce.fadd(<8 x float> %acc, <16 x float> %mult)
+ store <8 x float> %partial.reduce, ptr %accptr
+ ret void
+}
+
+define <4 x float> @fixed_fdot_wide(<4 x float> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fixed_fdot_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fcvtl v3.4s, v1.4h
+; CHECK-NEXT: fcvtl v4.4s, v2.4h
+; CHECK-NEXT: fcvtl2 v1.4s, v1.8h
+; CHECK-NEXT: fcvtl2 v2.4s, v2.8h
+; CHECK-NEXT: fmul v3.4s, v3.4s, v4.4s
+; CHECK-NEXT: fmul v1.4s, v1.4s, v2.4s
+; CHECK-NEXT: fadd v0.4s, v0.4s, v3.4s
+; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = fpext <8 x half> %a to <8 x float>
+ %b.wide = fpext <8 x half> %b to <8 x float>
+ %mult = fmul <8 x float> %a.wide, %b.wide
+ %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
+ ret <4 x float> %partial.reduce
+}