[llvm] 70f4b59 - Add `llvm.vector.partial.reduce.fadd` intrinsic (#159776)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 7 07:36:58 PST 2025
Author: Damian Heaton
Date: 2025-11-07T15:36:54Z
New Revision: 70f4b596cf453369ce4111c23e7e93633e5fe4b1
URL: https://github.com/llvm/llvm-project/commit/70f4b596cf453369ce4111c23e7e93633e5fe4b1
DIFF: https://github.com/llvm/llvm-project/commit/70f4b596cf453369ce4111c23e7e93633e5fe4b1.diff
LOG: Add `llvm.vector.partial.reduce.fadd` intrinsic (#159776)
With this intrinsic, and supporting SelectionDAG nodes, we can better
make use of instructions such as AArch64's `FDOT`.
Added:
llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
Modified:
llvm/docs/LangRef.rst
llvm/include/llvm/CodeGen/ISDOpcodes.h
llvm/include/llvm/CodeGen/SelectionDAGNodes.h
llvm/include/llvm/CodeGen/TargetLowering.h
llvm/include/llvm/IR/Intrinsics.td
llvm/include/llvm/Target/TargetSelectionDAG.td
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/lib/IR/Verifier.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Removed:
################################################################################
diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index bd0337f25c758..ab085ca0b1499 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -20368,6 +20368,77 @@ Arguments:
""""""""""
The argument to this intrinsic must be a vector of floating-point values.
+Vector Partial Reduction Intrinsics
+-----------------------------------
+
+Partial reductions of vectors can be expressed using the intrinsics described in
+this section. Each one reduces the concatenation of the two vector arguments
+down to the number of elements of the result vector type.
+
+Other than the reduction operator (e.g. add, fadd), the way in which the
+concatenated arguments are reduced is entirely unspecified. By their nature these
+intrinsics are not expected to be useful in isolation but can instead be used to
+implement the first phase of an overall reduction operation.
+
+The typical use case is loop vectorization where reductions are split into an
+in-loop phase, where maintaining an unordered vector result is important for
+performance, and an out-of-loop phase is required to calculate the final scalar
+result.
+
+By avoiding the introduction of new ordering constraints, these intrinsics
+enhance the ability to leverage a target's accumulation instructions.
+
+'``llvm.vector.partial.reduce.add.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+ declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v8i32(<4 x i32> %a, <8 x i32> %b)
+ declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v16i32(<4 x i32> %a, <16 x i32> %b)
+ declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv8i32(<vscale x 4 x i32> %a, <vscale x 8 x i32> %b)
+ declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv16i32(<vscale x 4 x i32> %a, <vscale x 16 x i32> %b)
+
+Arguments:
+""""""""""
+
+The first argument is an integer vector with the same type as the result.
+
+The second argument is a vector with a length that is a known integer multiple
+of the result type's length, while maintaining the same element type.
+
+'``llvm.vector.partial.reduce.fadd.*``' Intrinsic
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+This is an overloaded intrinsic.
+
+::
+
+ declare <4 x f32> @llvm.vector.partial.reduce.fadd.v4f32.v8f32(<4 x f32> %a, <8 x f32> %b)
+ declare <vscale x 4 x f32> @llvm.vector.partial.reduce.fadd.nxv4f32.nxv8f32(<vscale x 4 x f32> %a, <vscale x 8 x f32> %b)
+
+Arguments:
+""""""""""
+
+The first argument is a floating-point vector with the same type as the result.
+
+The second argument is a vector with a length that is a known integer multiple
+of the result type's length, while maintaining the same element type.
+
+Semantics:
+""""""""""
+
+As the way in which the arguments to this floating-point intrinsic are reduced
+is unspecified, this intrinsic will assume floating-point reassociation and
+contraction can be leveraged to implement the reduction, which may result in
+variations to the results due to reordering or by lowering to different
+instructions (including combining multiple instructions into a single one).
+
'``llvm.vector.insert``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -20741,50 +20812,6 @@ Note that it has the following implications:
- If ``%cnt`` is non-zero, the return value is non-zero as well.
- If ``%cnt`` is less than or equal to ``%max_lanes``, the return value is equal to ``%cnt``.
-'``llvm.vector.partial.reduce.add.*``' Intrinsic
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Syntax:
-"""""""
-This is an overloaded intrinsic.
-
-::
-
- declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v8i32(<4 x i32> %a, <8 x i32> %b)
- declare <4 x i32> @llvm.vector.partial.reduce.add.v4i32.v4i32.v16i32(<4 x i32> %a, <16 x i32> %b)
- declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv8i32(<vscale x 4 x i32> %a, <vscale x 8 x i32> %b)
- declare <vscale x 4 x i32> @llvm.vector.partial.reduce.add.nxv4i32.nxv4i32.nxv16i32(<vscale x 4 x i32> %a, <vscale x 16 x i32> %b)
-
-Overview:
-"""""""""
-
-The '``llvm.vector.partial.reduce.add.*``' intrinsics reduce the
-concatenation of the two vector arguments down to the number of elements of the
-result vector type.
-
-Arguments:
-""""""""""
-
-The first argument is an integer vector with the same type as the result.
-
-The second argument is a vector with a length that is a known integer multiple
-of the result's type, while maintaining the same element type.
-
-Semantics:
-""""""""""
-
-Other than the reduction operator (e.g., add) the way in which the concatenated
-arguments is reduced is entirely unspecified. By their nature these intrinsics
-are not expected to be useful in isolation but instead implement the first phase
-of an overall reduction operation.
-
-The typical use case is loop vectorization where reductions are split into an
-in-loop phase, where maintaining an unordered vector result is important for
-performance, and an out-of-loop phase to calculate the final scalar result.
-
-By avoiding the introduction of new ordering constraints, these intrinsics
-enhance the ability to leverage a target's accumulation instructions.
-
'``llvm.experimental.vector.histogram.*``' Intrinsic
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
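
A minimal usage sketch of the new intrinsic as documented above (the function name and vector widths are illustrative, not taken from the patch; the final scalar step uses the existing `llvm.vector.reduce.fadd`):

    define float @dot_f16_sketch(<4 x float> %acc, <8 x half> %a, <8 x half> %b) {
      %a.wide = fpext <8 x half> %a to <8 x float>
      %b.wide = fpext <8 x half> %b to <8 x float>
      %prod = fmul <8 x float> %a.wide, %b.wide
      ; In-loop phase: fold the eight products into the four-element accumulator;
      ; the order in which elements are combined is unspecified.
      %acc.next = call <4 x float> @llvm.vector.partial.reduce.fadd.v4f32.v8f32(<4 x float> %acc, <8 x float> %prod)
      ; Out-of-loop phase: collapse the accumulator to the final scalar.
      %sum = call float @llvm.vector.reduce.fadd.v4f32(float -0.0, <4 x float> %acc.next)
      ret float %sum
    }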
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index af60c04513ea7..b3a2ced70e628 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1516,6 +1516,7 @@ enum NodeType {
PARTIAL_REDUCE_SMLA, // sext, sext
PARTIAL_REDUCE_UMLA, // zext, zext
PARTIAL_REDUCE_SUMLA, // sext, zext
+ PARTIAL_REDUCE_FMLA, // fpext, fpext
// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 1759463ea7965..cd466dceb900f 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1949,6 +1949,10 @@ LLVM_ABI bool isNullOrNullSplat(SDValue V, bool AllowUndefs = false);
/// be zero.
LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
+/// Return true if the value is a constant floating-point value, or a splatted
+/// vector of a constant floating-point value, of 1.0 (with no undefs).
+LLVM_ABI bool isOneOrOneSplatFP(SDValue V, bool AllowUndefs = false);
+
/// Return true if the value is a constant -1 integer or a splatted vector of a
/// constant -1 integer (with no undefs).
/// Does not permit build vector implicit truncation.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 2550c2bee5f71..98565f423df3e 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1679,7 +1679,7 @@ class LLVM_ABI TargetLoweringBase {
LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
EVT InputVT) const {
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
- Opc == ISD::PARTIAL_REDUCE_SUMLA);
+ Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
InputVT.getSimpleVT().SimpleTy};
auto It = PartialReduceMLAActions.find(Key);
@@ -2793,7 +2793,7 @@ class LLVM_ABI TargetLoweringBase {
void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
LegalizeAction Action) {
assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA ||
- Opc == ISD::PARTIAL_REDUCE_SUMLA);
+ Opc == ISD::PARTIAL_REDUCE_SUMLA || Opc == ISD::PARTIAL_REDUCE_FMLA);
assert(AccVT.isValid() && InputVT.isValid() &&
"setPartialReduceMLAAction types aren't valid");
PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 520130c6dcb4c..07aa2faffa7c5 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -2813,6 +2813,10 @@ def int_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
[IntrNoMem,
IntrSpeculatable]>;
+def int_vector_partial_reduce_fadd : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+ [llvm_anyfloat_ty, llvm_anyfloat_ty],
+ [IntrNoMem]>;
+
//===----------------- Pointer Authentication Intrinsics ------------------===//
//
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 07a858fd682fc..a9750a5ab03f9 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -527,6 +527,8 @@ def partial_reduce_smla : SDNode<"ISD::PARTIAL_REDUCE_SMLA",
SDTPartialReduceMLA>;
def partial_reduce_sumla : SDNode<"ISD::PARTIAL_REDUCE_SUMLA",
SDTPartialReduceMLA>;
+def partial_reduce_fmla : SDNode<"ISD::PARTIAL_REDUCE_FMLA",
+ SDTPartialReduceMLA>;
def fadd : SDNode<"ISD::FADD" , SDTFPBinOp, [SDNPCommutative]>;
def fsub : SDNode<"ISD::FSUB" , SDTFPBinOp>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 10daca57558ff..f144f17d5a8f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2042,6 +2042,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
return visitPARTIAL_REDUCE_MLA(N);
case ISD::VECTOR_COMPRESS: return visitVECTOR_COMPRESS(N);
case ISD::LIFETIME_END: return visitLIFETIME_END(N);
@@ -13006,6 +13007,9 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
//
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
+//
+// partial_reduce_fmla(acc, fmul(fpext(a), fpext(b)), splat(1.0))
+// -> partial_reduce_fmla(acc, a, b)
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDLoc DL(N);
auto *Context = DAG.getContext();
@@ -13014,7 +13018,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
SDValue Op2 = N->getOperand(2);
unsigned Opc = Op1->getOpcode();
- if (Opc != ISD::MUL && Opc != ISD::SHL)
+ if (Opc != ISD::MUL && Opc != ISD::FMUL && Opc != ISD::SHL)
return SDValue();
SDValue LHS = Op1->getOperand(0);
@@ -13033,13 +13037,16 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
Opc = ISD::MUL;
}
- APInt C;
- if (Opc != ISD::MUL || !ISD::isConstantSplatVector(Op2.getNode(), C) ||
- !C.isOne())
+ if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) &&
+ !(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(Op2)))
return SDValue();
+ auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
+ return (ISD::isExtOpcode(Opcode) || Opcode == ISD::FP_EXTEND);
+ };
+
unsigned LHSOpcode = LHS->getOpcode();
- if (!ISD::isExtOpcode(LHSOpcode))
+ if (!IsIntOrFPExtOpcode(LHSOpcode))
return SDValue();
SDValue LHSExtOp = LHS->getOperand(0);
@@ -13047,6 +13054,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
// -> partial_reduce_*mla(acc, x, C)
+ APInt C;
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
// TODO: Make use of partial_reduce_sumla here
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
@@ -13071,7 +13079,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
}
unsigned RHSOpcode = RHS->getOpcode();
- if (!ISD::isExtOpcode(RHSOpcode))
+ if (!IsIntOrFPExtOpcode(RHSOpcode))
return SDValue();
SDValue RHSExtOp = RHS->getOperand(0);
@@ -13088,6 +13096,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
else if (LHSOpcode == ISD::ZERO_EXTEND && RHSOpcode == ISD::SIGN_EXTEND) {
NewOpc = ISD::PARTIAL_REDUCE_SUMLA;
std::swap(LHSExtOp, RHSExtOp);
+ } else if (LHSOpcode == ISD::FP_EXTEND && RHSOpcode == ISD::FP_EXTEND) {
+ NewOpc = ISD::PARTIAL_REDUCE_FMLA;
} else
return SDValue();
// For a 2-stage extend the signedness of both of the extends must match
@@ -13115,30 +13125,33 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
// partial.reduce.sumla(acc, sext(op), splat(1))
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
+// partial.reduce.fmla(acc, fpext(op), splat(1.0))
+// -> partial.reduce.fmla(acc, op, splat(1.0))
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
SDLoc DL(N);
SDValue Acc = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Op2 = N->getOperand(2);
- APInt ConstantOne;
- if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ if (!llvm::isOneOrOneSplat(Op2) && !llvm::isOneOrOneSplatFP(Op2))
return SDValue();
unsigned Op1Opcode = Op1.getOpcode();
- if (!ISD::isExtOpcode(Op1Opcode))
+ if (!ISD::isExtOpcode(Op1Opcode) && Op1Opcode != ISD::FP_EXTEND)
return SDValue();
- bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+ bool Op1IsSigned =
+ Op1Opcode == ISD::SIGN_EXTEND || Op1Opcode == ISD::FP_EXTEND;
bool NodeIsSigned = N->getOpcode() != ISD::PARTIAL_REDUCE_UMLA;
EVT AccElemVT = Acc.getValueType().getVectorElementType();
if (Op1IsSigned != NodeIsSigned &&
Op1.getValueType().getVectorElementType() != AccElemVT)
return SDValue();
- unsigned NewOpcode =
- Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ unsigned NewOpcode = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? ISD::PARTIAL_REDUCE_FMLA
+ : Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA
+ : ISD::PARTIAL_REDUCE_UMLA;
SDValue UnextOp1 = Op1.getOperand(0);
EVT UnextOp1VT = UnextOp1.getValueType();
@@ -13148,8 +13161,12 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
return SDValue();
+ SDValue Constant = N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA
+ ? DAG.getConstantFP(1, DL, UnextOp1VT)
+ : DAG.getConstant(1, DL, UnextOp1VT);
+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
- DAG.getConstant(1, DL, UnextOp1VT));
+ Constant);
}
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
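
For reference, the IR shape these combines target is the widening multiply exercised by `fdot_wide_nxv4f32` in the new tests (one example type combination shown). SelectionDAGBuilder emits the node with a splat(1.0) multiplicand, and foldPartialReduceMLAMulOp then looks through the fmul of two fp_extends so the narrow operands feed the node directly:

    %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
    %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
    %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
    %red = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
    ; DAG before the combine: partial_reduce_fmla acc, (fmul (fp_extend a), (fp_extend b)), splat(1.0)
    ; DAG after the combine:  partial_reduce_fmla acc, a, b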
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 8e423c4f83b38..94751be5b7986 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -534,6 +534,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
Action =
TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
Node->getOperand(1).getValueType());
@@ -1243,6 +1244,7 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
Results.push_back(TLI.expandPartialReduceMLA(Node, DAG));
return;
case ISD::VECREDUCE_SEQ_FADD:
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index bb4a8d9967f94..dd5c011bfe784 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1474,6 +1474,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
break;
case ISD::GET_ACTIVE_LANE_MASK:
@@ -3689,6 +3690,7 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
Res = SplitVecOp_PARTIAL_REDUCE_MLA(N);
break;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0a06752ae1846..bbc1d734cfef5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -8404,7 +8404,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
- case ISD::PARTIAL_REDUCE_SUMLA: {
+ case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA: {
[[maybe_unused]] EVT AccVT = N1.getValueType();
[[maybe_unused]] EVT Input1VT = N2.getValueType();
[[maybe_unused]] EVT Input2VT = N3.getValueType();
@@ -13064,6 +13065,11 @@ bool llvm::isOneOrOneSplat(SDValue N, bool AllowUndefs) {
return C && C->isOne();
}
+bool llvm::isOneOrOneSplatFP(SDValue N, bool AllowUndefs) {
+ ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs);
+ return C && C->isExactlyValue(1.0);
+}
+
bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
N = peekThroughBitcasts(N);
unsigned BitWidth = N.getScalarValueSizeInBits();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 2f598b2f03e24..88b0809b767b5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8187,6 +8187,14 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
Input, DAG.getConstant(1, sdl, Input.getValueType())));
return;
}
+ case Intrinsic::vector_partial_reduce_fadd: {
+ SDValue Acc = getValue(I.getOperand(0));
+ SDValue Input = getValue(I.getOperand(1));
+ setValue(&I, DAG.getNode(
+ ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
+ Input, DAG.getConstantFP(1.0, sdl, Input.getValueType())));
+ return;
+ }
case Intrinsic::experimental_cttz_elts: {
auto DL = getCurSDLoc();
SDValue Op = getValue(I.getOperand(0));
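
As with `vector_partial_reduce_add`, the builder bridges the two-operand intrinsic to the three-operand DAG node by supplying an implicit splat(1.0) multiplicand, which the DAGCombiner folds away again when it is redundant. Conceptually:

    %r = call <4 x float> @llvm.vector.partial.reduce.fadd.v4f32.v8f32(<4 x float> %acc, <8 x float> %in)
    ; is built as: v4f32 = partial_reduce_fmla %acc, %in, splat(1.0)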
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index d3e162879304d..ec5edd5f13978 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -590,6 +590,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
return "partial_reduce_smla";
case ISD::PARTIAL_REDUCE_SUMLA:
return "partial_reduce_sumla";
+ case ISD::PARTIAL_REDUCE_FMLA:
+ return "partial_reduce_fmla";
case ISD::LOOP_DEPENDENCE_WAR_MASK:
return "loop_dep_war";
case ISD::LOOP_DEPENDENCE_RAW_MASK:
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 9bdf82210fed1..b51d6649af2ec 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -12074,22 +12074,32 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
MulOpVT.getVectorElementCount());
- unsigned ExtOpcLHS = N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA
- ? ISD::ZERO_EXTEND
- : ISD::SIGN_EXTEND;
- unsigned ExtOpcRHS = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA
- ? ISD::SIGN_EXTEND
- : ISD::ZERO_EXTEND;
+ unsigned ExtOpcLHS, ExtOpcRHS;
+ switch (N->getOpcode()) {
+ default:
+ llvm_unreachable("Unexpected opcode");
+ case ISD::PARTIAL_REDUCE_UMLA:
+ ExtOpcLHS = ExtOpcRHS = ISD::ZERO_EXTEND;
+ break;
+ case ISD::PARTIAL_REDUCE_SMLA:
+ ExtOpcLHS = ExtOpcRHS = ISD::SIGN_EXTEND;
+ break;
+ case ISD::PARTIAL_REDUCE_FMLA:
+ ExtOpcLHS = ExtOpcRHS = ISD::FP_EXTEND;
+ break;
+ }
if (ExtMulOpVT != MulOpVT) {
MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
}
SDValue Input = MulLHS;
- APInt ConstantOne;
- if (!ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) ||
- !ConstantOne.isOne())
+ if (N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA) {
+ if (!llvm::isOneOrOneSplatFP(MulRHS))
+ Input = DAG.getNode(ISD::FMUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+ } else if (!llvm::isOneOrOneSplat(MulRHS)) {
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
+ }
unsigned Stride = AccVT.getVectorMinNumElements();
unsigned ScaleFactor = MulOpVT.getVectorMinNumElements() / Stride;
@@ -12099,10 +12109,13 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
for (unsigned I = 0; I < ScaleFactor; I++)
Subvectors.push_back(DAG.getExtractSubvector(DL, AccVT, Input, I * Stride));
+ unsigned FlatNode =
+ N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FADD : ISD::ADD;
+
// Flatten the subvector tree
while (Subvectors.size() > 1) {
Subvectors.push_back(
- DAG.getNode(ISD::ADD, DL, AccVT, {Subvectors[0], Subvectors[1]}));
+ DAG.getNode(FlatNode, DL, AccVT, {Subvectors[0], Subvectors[1]}));
Subvectors.pop_front();
Subvectors.pop_front();
}
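
When PARTIAL_REDUCE_FMLA is not legal for the requested types, this expansion splits the (optionally multiplied) wide input into accumulator-sized subvectors and chains FADDs, which is where the back-to-back fadd instructions in the SVE2 CHECK lines of the new tests come from. For an <8 x float> input and <4 x float> accumulator the result is roughly equivalent to:

    %lo = call <4 x float> @llvm.vector.extract.v4f32.v8f32(<8 x float> %input, i64 0)
    %hi = call <4 x float> @llvm.vector.extract.v4f32.v8f32(<8 x float> %input, i64 4)
    %t0 = fadd <4 x float> %acc, %lo
    %res = fadd <4 x float> %t0, %hi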
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index f1e473a78e3ac..59eb870798235 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -6583,6 +6583,7 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
}
break;
}
+ case Intrinsic::vector_partial_reduce_fadd:
case Intrinsic::vector_partial_reduce_add: {
VectorType *AccTy = cast<VectorType>(Call.getArgOperand(0)->getType());
VectorType *VecTy = cast<VectorType>(Call.getArgOperand(1)->getType());
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 40e6400756c7c..c8a038fa99b30 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1916,6 +1916,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
}
+
+ // Handle floating-point partial reduction
+ if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, MVT::nxv4f32,
+ MVT::nxv8f16, Legal);
+ }
}
// Handle non-aliasing elements mask
@@ -2283,6 +2289,11 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
}
+ if (Subtarget->hasSVE2p1() && VT.getVectorElementType() == MVT::f32) {
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, VT,
+ MVT::getVectorVT(MVT::f16, NumElts * 2), Custom);
+ }
+
// Lower fixed length vector operations to scalable equivalents.
setOperationAction(ISD::ABDS, VT, Default);
setOperationAction(ISD::ABDU, VT, Default);
@@ -7875,6 +7886,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
case ISD::PARTIAL_REDUCE_SMLA:
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SUMLA:
+ case ISD::PARTIAL_REDUCE_FMLA:
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 3b268dcbca600..e1f43867bbe5b 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -375,6 +375,11 @@ def AArch64fclamp : PatFrags<(ops node:$Zd, node:$Zn, node:$Zm),
node:$Zm)
]>;
+def AArch64fdot : PatFrags<(ops node:$Zd, node:$Zn, node:$Zm),
+ [(int_aarch64_sve_fdot_x2 node:$Zd, node:$Zn, node:$Zm),
+ (partial_reduce_fmla node:$Zd, node:$Zn, node:$Zm)
+ ]>;
+
def SDT_AArch64FCVT : SDTypeProfile<1, 3, [
SDTCisVec<0>, SDTCisVec<1>, SDTCisVec<2>, SDTCisVec<3>,
SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0,1>, SDTCisSameAs<0,3>
@@ -4251,7 +4256,7 @@ defm PSEL_PPPRI : sve2_int_perm_sel_p<"psel", int_aarch64_sve_psel>;
let Predicates = [HasSVE2p1_or_SME2] in {
defm FCLAMP_ZZZ : sve_fp_clamp<"fclamp", AArch64fclamp>;
-defm FDOT_ZZZ_S : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, int_aarch64_sve_fdot_x2>;
+defm FDOT_ZZZ_S : sve_float_dot<0b0, 0b0, ZPR32, ZPR16, "fdot", nxv8f16, AArch64fdot>;
defm FDOT_ZZZI_S : sve_float_dot_indexed<0b0, 0b00, ZPR16, ZPR3b16, "fdot", nxv8f16, int_aarch64_sve_fdot_lane_x2>;
defm BFMLSLB_ZZZ_S : sve2_fp_mla_long<0b110, "bfmlslb", nxv4f32, nxv8bf16, int_aarch64_sve_bfmlslb>;
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
new file mode 100644
index 0000000000000..9dbe096ebdb57
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fdot.ll
@@ -0,0 +1,93 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mattr=+sve2 < %s | FileCheck %s --check-prefixes=CHECK,SVE2
+; RUN: llc -global-isel -global-isel-abort=2 -mattr=+sve2 < %s | FileCheck %s --check-prefixes=CHECK,SVE2
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,SVE2P1
+; RUN: llc -global-isel -global-isel-abort=2 -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,SVE2P1
+
+target triple = "aarch64-linux-gnu"
+
+define <vscale x 4 x float> @fdot_wide_nxv4f32(<vscale x 4 x float> %acc, <vscale x 8 x half> %a, <vscale x 8 x half> %b) {
+; SVE2-LABEL: fdot_wide_nxv4f32:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: uunpklo z3.s, z1.h
+; SVE2-NEXT: uunpklo z4.s, z2.h
+; SVE2-NEXT: ptrue p0.s
+; SVE2-NEXT: uunpkhi z1.s, z1.h
+; SVE2-NEXT: uunpkhi z2.s, z2.h
+; SVE2-NEXT: fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT: fcvt z4.s, p0/m, z4.h
+; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT: fmul z3.s, z3.s, z4.s
+; SVE2-NEXT: fmul z1.s, z1.s, z2.s
+; SVE2-NEXT: fadd z0.s, z0.s, z3.s
+; SVE2-NEXT: fadd z0.s, z0.s, z1.s
+; SVE2-NEXT: ret
+;
+; SVE2P1-LABEL: fdot_wide_nxv4f32:
+; SVE2P1: // %bb.0: // %entry
+; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT: ret
+entry:
+ %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
+ %mult = fmul <vscale x 8 x float> %a.wide, %b.wide
+ %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %mult)
+ ret <vscale x 4 x float> %partial.reduce
+}
+
+define <vscale x 4 x float> @fdot_splat_nxv4f32(<vscale x 4 x float> %acc, <vscale x 8 x half> %a) {
+; SVE2-LABEL: fdot_splat_nxv4f32:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: uunpklo z2.s, z1.h
+; SVE2-NEXT: ptrue p0.s
+; SVE2-NEXT: uunpkhi z1.s, z1.h
+; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT: fadd z0.s, z0.s, z2.s
+; SVE2-NEXT: fadd z0.s, z0.s, z1.s
+; SVE2-NEXT: ret
+;
+; SVE2P1-LABEL: fdot_splat_nxv4f32:
+; SVE2P1: // %bb.0: // %entry
+; SVE2P1-NEXT: fmov z2.h, #1.00000000
+; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT: ret
+entry:
+ %a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
+ %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %a.wide)
+ ret <vscale x 4 x float> %partial.reduce
+}
+
+define <vscale x 8 x half> @partial_reduce_nxv8f16(<vscale x 8 x half> %acc, <vscale x 16 x half> %a) {
+; CHECK-LABEL: partial_reduce_nxv8f16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fadd z0.h, z0.h, z1.h
+; CHECK-NEXT: fadd z0.h, z0.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %partial.reduce = call <vscale x 8 x half> @llvm.vector.partial.reduce.fadd(<vscale x 8 x half> %acc, <vscale x 16 x half> %a)
+ ret <vscale x 8 x half> %partial.reduce
+}
+
+define <vscale x 4 x float> @partial_reduce_nxv4f32(<vscale x 4 x float> %acc, <vscale x 8 x float> %a) {
+; CHECK-LABEL: partial_reduce_nxv4f32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fadd z0.s, z0.s, z1.s
+; CHECK-NEXT: fadd z0.s, z0.s, z2.s
+; CHECK-NEXT: ret
+entry:
+ %partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %a)
+ ret <vscale x 4 x float> %partial.reduce
+}
+
+define <vscale x 2 x double> @partial_reduce_nxv2f64(<vscale x 2 x double> %acc, <vscale x 4 x double> %a) {
+; CHECK-LABEL: partial_reduce_nxv2f64:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fadd z0.d, z0.d, z1.d
+; CHECK-NEXT: fadd z0.d, z0.d, z2.d
+; CHECK-NEXT: ret
+entry:
+ %partial.reduce = call <vscale x 2 x double> @llvm.vector.partial.reduce.fadd(<vscale x 2 x double> %acc, <vscale x 4 x double> %a)
+ ret <vscale x 2 x double> %partial.reduce
+}
diff --git a/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll b/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
new file mode 100644
index 0000000000000..89216ce2cb72b
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2p1-fixed-length-fdot.ll
@@ -0,0 +1,230 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mattr=+sve2 < %s | FileCheck %s --check-prefixes=CHECK,SVE2
+; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,SVE2P1
+
+target triple = "aarch64-linux-gnu"
+
+define void @fdot_wide_v8f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,0) {
+; SVE2-LABEL: fdot_wide_v8f32:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: ptrue p0.s, vl8
+; SVE2-NEXT: mov x8, #8 // =0x8
+; SVE2-NEXT: ld1h { z0.s }, p0/z, [x1]
+; SVE2-NEXT: ld1h { z1.s }, p0/z, [x2]
+; SVE2-NEXT: ld1h { z2.s }, p0/z, [x1, x8, lsl #1]
+; SVE2-NEXT: ld1h { z3.s }, p0/z, [x2, x8, lsl #1]
+; SVE2-NEXT: fcvt z0.s, p0/m, z0.h
+; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT: fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT: fmul z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT: ld1w { z1.s }, p0/z, [x0]
+; SVE2-NEXT: fmul z2.s, p0/m, z2.s, z3.s
+; SVE2-NEXT: fadd z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT: fadd z0.s, p0/m, z0.s, z2.s
+; SVE2-NEXT: st1w { z0.s }, p0, [x0]
+; SVE2-NEXT: ret
+;
+; SVE2P1-LABEL: fdot_wide_v8f32:
+; SVE2P1: // %bb.0: // %entry
+; SVE2P1-NEXT: ptrue p0.s, vl8
+; SVE2P1-NEXT: ptrue p1.h, vl16
+; SVE2P1-NEXT: ld1w { z0.s }, p0/z, [x0]
+; SVE2P1-NEXT: ld1h { z1.h }, p1/z, [x1]
+; SVE2P1-NEXT: ld1h { z2.h }, p1/z, [x2]
+; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT: st1w { z0.s }, p0, [x0]
+; SVE2P1-NEXT: ret
+entry:
+ %acc = load <8 x float>, ptr %accptr
+ %a = load <16 x half>, ptr %aptr
+ %b = load <16 x half>, ptr %bptr
+ %a.wide = fpext <16 x half> %a to <16 x float>
+ %b.wide = fpext <16 x half> %b to <16 x float>
+ %mult = fmul <16 x float> %a.wide, %b.wide
+ %partial.reduce = call <8 x float> @llvm.vector.partial.reduce.fadd(<8 x float> %acc, <16 x float> %mult)
+ store <8 x float> %partial.reduce, ptr %accptr
+ ret void
+}
+
+define void @fdot_wide_v16f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(4,0) {
+; SVE2-LABEL: fdot_wide_v16f32:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: ptrue p0.s, vl16
+; SVE2-NEXT: mov x8, #16 // =0x10
+; SVE2-NEXT: ld1h { z0.s }, p0/z, [x1]
+; SVE2-NEXT: ld1h { z1.s }, p0/z, [x2]
+; SVE2-NEXT: ld1h { z2.s }, p0/z, [x1, x8, lsl #1]
+; SVE2-NEXT: ld1h { z3.s }, p0/z, [x2, x8, lsl #1]
+; SVE2-NEXT: fcvt z0.s, p0/m, z0.h
+; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT: fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT: fmul z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT: ld1w { z1.s }, p0/z, [x0]
+; SVE2-NEXT: fmul z2.s, p0/m, z2.s, z3.s
+; SVE2-NEXT: fadd z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT: fadd z0.s, p0/m, z0.s, z2.s
+; SVE2-NEXT: st1w { z0.s }, p0, [x0]
+; SVE2-NEXT: ret
+;
+; SVE2P1-LABEL: fdot_wide_v16f32:
+; SVE2P1: // %bb.0: // %entry
+; SVE2P1-NEXT: ptrue p0.s, vl16
+; SVE2P1-NEXT: ptrue p1.h, vl32
+; SVE2P1-NEXT: ld1w { z0.s }, p0/z, [x0]
+; SVE2P1-NEXT: ld1h { z1.h }, p1/z, [x1]
+; SVE2P1-NEXT: ld1h { z2.h }, p1/z, [x2]
+; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT: st1w { z0.s }, p0, [x0]
+; SVE2P1-NEXT: ret
+entry:
+ %acc = load <16 x float>, ptr %accptr
+ %a = load <32 x half>, ptr %aptr
+ %b = load <32 x half>, ptr %bptr
+ %a.wide = fpext <32 x half> %a to <32 x float>
+ %b.wide = fpext <32 x half> %b to <32 x float>
+ %mult = fmul <32 x float> %a.wide, %b.wide
+ %partial.reduce = call <16 x float> @llvm.vector.partial.reduce.fadd(<16 x float> %acc, <32 x float> %mult)
+ store <16 x float> %partial.reduce, ptr %accptr
+ ret void
+}
+
+define void @fdot_wide_v32f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(8,0) {
+; SVE2-LABEL: fdot_wide_v32f32:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: ptrue p0.s, vl32
+; SVE2-NEXT: mov x8, #32 // =0x20
+; SVE2-NEXT: ld1h { z0.s }, p0/z, [x1]
+; SVE2-NEXT: ld1h { z1.s }, p0/z, [x2]
+; SVE2-NEXT: ld1h { z2.s }, p0/z, [x1, x8, lsl #1]
+; SVE2-NEXT: ld1h { z3.s }, p0/z, [x2, x8, lsl #1]
+; SVE2-NEXT: fcvt z0.s, p0/m, z0.h
+; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT: fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT: fmul z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT: ld1w { z1.s }, p0/z, [x0]
+; SVE2-NEXT: fmul z2.s, p0/m, z2.s, z3.s
+; SVE2-NEXT: fadd z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT: fadd z0.s, p0/m, z0.s, z2.s
+; SVE2-NEXT: st1w { z0.s }, p0, [x0]
+; SVE2-NEXT: ret
+;
+; SVE2P1-LABEL: fdot_wide_v32f32:
+; SVE2P1: // %bb.0: // %entry
+; SVE2P1-NEXT: ptrue p0.s, vl32
+; SVE2P1-NEXT: ptrue p1.h, vl64
+; SVE2P1-NEXT: ld1w { z0.s }, p0/z, [x0]
+; SVE2P1-NEXT: ld1h { z1.h }, p1/z, [x1]
+; SVE2P1-NEXT: ld1h { z2.h }, p1/z, [x2]
+; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT: st1w { z0.s }, p0, [x0]
+; SVE2P1-NEXT: ret
+entry:
+ %acc = load <32 x float>, ptr %accptr
+ %a = load <64 x half>, ptr %aptr
+ %b = load <64 x half>, ptr %bptr
+ %a.wide = fpext <64 x half> %a to <64 x float>
+ %b.wide = fpext <64 x half> %b to <64 x float>
+ %mult = fmul <64 x float> %a.wide, %b.wide
+ %partial.reduce = call <32 x float> @llvm.vector.partial.reduce.fadd(<32 x float> %acc, <64 x float> %mult)
+ store <32 x float> %partial.reduce, ptr %accptr
+ ret void
+}
+
+define void @fdot_wide_v64f32(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(16,0) {
+; SVE2-LABEL: fdot_wide_v64f32:
+; SVE2: // %bb.0: // %entry
+; SVE2-NEXT: ptrue p0.s, vl64
+; SVE2-NEXT: mov x8, #64 // =0x40
+; SVE2-NEXT: ld1h { z0.s }, p0/z, [x1]
+; SVE2-NEXT: ld1h { z1.s }, p0/z, [x2]
+; SVE2-NEXT: ld1h { z2.s }, p0/z, [x1, x8, lsl #1]
+; SVE2-NEXT: ld1h { z3.s }, p0/z, [x2, x8, lsl #1]
+; SVE2-NEXT: fcvt z0.s, p0/m, z0.h
+; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
+; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
+; SVE2-NEXT: fcvt z3.s, p0/m, z3.h
+; SVE2-NEXT: fmul z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT: ld1w { z1.s }, p0/z, [x0]
+; SVE2-NEXT: fmul z2.s, p0/m, z2.s, z3.s
+; SVE2-NEXT: fadd z0.s, p0/m, z0.s, z1.s
+; SVE2-NEXT: fadd z0.s, p0/m, z0.s, z2.s
+; SVE2-NEXT: st1w { z0.s }, p0, [x0]
+; SVE2-NEXT: ret
+;
+; SVE2P1-LABEL: fdot_wide_v64f32:
+; SVE2P1: // %bb.0: // %entry
+; SVE2P1-NEXT: ptrue p0.s, vl64
+; SVE2P1-NEXT: ptrue p1.h, vl128
+; SVE2P1-NEXT: ld1w { z0.s }, p0/z, [x0]
+; SVE2P1-NEXT: ld1h { z1.h }, p1/z, [x1]
+; SVE2P1-NEXT: ld1h { z2.h }, p1/z, [x2]
+; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
+; SVE2P1-NEXT: st1w { z0.s }, p0, [x0]
+; SVE2P1-NEXT: ret
+entry:
+ %acc = load <64 x float>, ptr %accptr
+ %a = load <128 x half>, ptr %aptr
+ %b = load <128 x half>, ptr %bptr
+ %a.wide = fpext <128 x half> %a to <128 x float>
+ %b.wide = fpext <128 x half> %b to <128 x float>
+ %mult = fmul <128 x float> %a.wide, %b.wide
+ %partial.reduce = call <64 x float> @llvm.vector.partial.reduce.fadd(<64 x float> %acc, <128 x float> %mult)
+ store <64 x float> %partial.reduce, ptr %accptr
+ ret void
+}
+
+define <4 x float> @fixed_fdot_wide(<4 x float> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: fixed_fdot_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fcvtl v3.4s, v1.4h
+; CHECK-NEXT: fcvtl v4.4s, v2.4h
+; CHECK-NEXT: fcvtl2 v1.4s, v1.8h
+; CHECK-NEXT: fcvtl2 v2.4s, v2.8h
+; CHECK-NEXT: fmul v3.4s, v3.4s, v4.4s
+; CHECK-NEXT: fmul v1.4s, v1.4s, v2.4s
+; CHECK-NEXT: fadd v0.4s, v0.4s, v3.4s
+; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: ret
+entry:
+ %a.wide = fpext <8 x half> %a to <8 x float>
+ %b.wide = fpext <8 x half> %b to <8 x float>
+ %mult = fmul <8 x float> %a.wide, %b.wide
+ %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %mult)
+ ret <4 x float> %partial.reduce
+}
+
+define <8 x half> @partial_reduce_half(<8 x half> %acc, <16 x half> %a) {
+; CHECK-LABEL: partial_reduce_half:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT: fadd v0.8h, v0.8h, v2.8h
+; CHECK-NEXT: ret
+entry:
+ %partial.reduce = call <8 x half> @llvm.vector.partial.reduce.fadd(<8 x half> %acc, <16 x half> %a)
+ ret <8 x half> %partial.reduce
+}
+
+define <4 x float> @partial_reduce_float(<4 x float> %acc, <8 x float> %a) {
+; CHECK-LABEL: partial_reduce_float:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fadd v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: fadd v0.4s, v0.4s, v2.4s
+; CHECK-NEXT: ret
+entry:
+ %partial.reduce = call <4 x float> @llvm.vector.partial.reduce.fadd(<4 x float> %acc, <8 x float> %a)
+ ret <4 x float> %partial.reduce
+}
+
+define <2 x double> @partial_reduce_double(<2 x double> %acc, <4 x double> %a) {
+; CHECK-LABEL: partial_reduce_double:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: fadd v0.2d, v0.2d, v1.2d
+; CHECK-NEXT: fadd v0.2d, v0.2d, v2.2d
+; CHECK-NEXT: ret
+entry:
+ %partial.reduce = call <2 x double> @llvm.vector.partial.reduce.fadd(<2 x double> %acc, <4 x double> %a)
+ ret <2 x double> %partial.reduce
+}