[llvm] dda1e74 - [Legalize] Add legalizations for VECREDUCE_SEQ_FADD
Cameron McInally via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 30 14:03:05 PDT 2020
Author: Cameron McInally
Date: 2020-10-30T16:02:55-05:00
New Revision: dda1e74b58bd6eac06e346e6246b906362532f46
URL: https://github.com/llvm/llvm-project/commit/dda1e74b58bd6eac06e346e6246b906362532f46
DIFF: https://github.com/llvm/llvm-project/commit/dda1e74b58bd6eac06e346e6246b906362532f46.diff
LOG: [Legalize] Add legalizations for VECREDUCE_SEQ_FADD
Add Legalization support for VECREDUCE_SEQ_FADD, so that we don't need to depend on ExpandReductionsPass.
Differential Revision: https://reviews.llvm.org/D90247
Added:
Modified:
llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/lib/CodeGen/TargetLoweringBase.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
llvm/lib/Target/ARM/ARMTargetTransformInfo.h
llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 064e7608cb6a..8922c9b8db78 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4447,6 +4447,9 @@ class TargetLowering : public TargetLoweringBase {
/// only the first Count elements of the vector are used.
SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const;
+ /// Expand a VECREDUCE_SEQ_* into an explicit ordered calculation.
+ SDValue expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const;
+
/// Expand an SREM or UREM using SDIV/UDIV or SDIVREM/UDIVREM, if legal.
/// Returns true if the expansion was successful.
bool expandREM(SDNode *Node, SDValue &Result, SelectionDAG &DAG) const;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 078d8ad27112..5900350a3fe0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1165,6 +1165,10 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Action = TLI.getOperationAction(
Node->getOpcode(), Node->getOperand(0).getValueType());
break;
+ case ISD::VECREDUCE_SEQ_FADD:
+ Action = TLI.getOperationAction(
+ Node->getOpcode(), Node->getOperand(1).getValueType());
+ break;
default:
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
Action = TargetLowering::Legal;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index d738ef9df7f1..b8360290b3ca 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -140,6 +140,9 @@ void DAGTypeLegalizer::SoftenFloatResult(SDNode *N, unsigned ResNo) {
case ISD::VECREDUCE_FMAX:
R = SoftenFloatRes_VECREDUCE(N);
break;
+ case ISD::VECREDUCE_SEQ_FADD:
+ R = SoftenFloatRes_VECREDUCE_SEQ(N);
+ break;
}
// If R is null, the sub-method took care of registering the result.
@@ -784,6 +787,10 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_VECREDUCE(SDNode *N) {
return SDValue();
}
+SDValue DAGTypeLegalizer::SoftenFloatRes_VECREDUCE_SEQ(SDNode *N) {
+ ReplaceValueWith(SDValue(N, 0), TLI.expandVecReduceSeq(N, DAG));
+ return SDValue();
+}
//===----------------------------------------------------------------------===//
// Convert Float Operand to Integer
@@ -2254,6 +2261,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
case ISD::VECREDUCE_FMAX:
R = PromoteFloatRes_VECREDUCE(N);
break;
+ case ISD::VECREDUCE_SEQ_FADD:
+ R = PromoteFloatRes_VECREDUCE_SEQ(N);
+ break;
}
if (R.getNode())
@@ -2494,6 +2504,11 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_VECREDUCE(SDNode *N) {
return SDValue();
}
+SDValue DAGTypeLegalizer::PromoteFloatRes_VECREDUCE_SEQ(SDNode *N) {
+ ReplaceValueWith(SDValue(N, 0), TLI.expandVecReduceSeq(N, DAG));
+ return SDValue();
+}
+
SDValue DAGTypeLegalizer::BitcastToInt_ATOMIC_SWAP(SDNode *N) {
EVT VT = N->getValueType(0);
@@ -2608,6 +2623,9 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
case ISD::VECREDUCE_FMAX:
R = SoftPromoteHalfRes_VECREDUCE(N);
break;
+ case ISD::VECREDUCE_SEQ_FADD:
+ R = SoftPromoteHalfRes_VECREDUCE_SEQ(N);
+ break;
}
if (R.getNode())
@@ -2806,6 +2824,12 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_VECREDUCE(SDNode *N) {
return SDValue();
}
+SDValue DAGTypeLegalizer::SoftPromoteHalfRes_VECREDUCE_SEQ(SDNode *N) {
+ // Expand and soften.
+ ReplaceValueWith(SDValue(N, 0), TLI.expandVecReduceSeq(N, DAG));
+ return SDValue();
+}
+
//===----------------------------------------------------------------------===//
// Half Operand Soft Promotion
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index cec91c5bbb5b..6ed480f5c17c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -551,6 +551,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SoftenFloatRes_VAARG(SDNode *N);
SDValue SoftenFloatRes_XINT_TO_FP(SDNode *N);
SDValue SoftenFloatRes_VECREDUCE(SDNode *N);
+ SDValue SoftenFloatRes_VECREDUCE_SEQ(SDNode *N);
// Convert Float Operand to Integer.
bool SoftenFloatOperand(SDNode *N, unsigned OpNo);
@@ -670,6 +671,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue BitcastToInt_ATOMIC_SWAP(SDNode *N);
SDValue PromoteFloatRes_XINT_TO_FP(SDNode *N);
SDValue PromoteFloatRes_VECREDUCE(SDNode *N);
+ SDValue PromoteFloatRes_VECREDUCE_SEQ(SDNode *N);
bool PromoteFloatOperand(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
@@ -708,6 +710,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SoftPromoteHalfRes_XINT_TO_FP(SDNode *N);
SDValue SoftPromoteHalfRes_UNDEF(SDNode *N);
SDValue SoftPromoteHalfRes_VECREDUCE(SDNode *N);
+ SDValue SoftPromoteHalfRes_VECREDUCE_SEQ(SDNode *N);
bool SoftPromoteHalfOperand(SDNode *N, unsigned OpNo);
SDValue SoftPromoteHalfOp_BITCAST(SDNode *N);
@@ -774,6 +777,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue ScalarizeVecOp_FP_ROUND(SDNode *N, unsigned OpNo);
SDValue ScalarizeVecOp_STRICT_FP_ROUND(SDNode *N, unsigned OpNo);
SDValue ScalarizeVecOp_VECREDUCE(SDNode *N);
+ SDValue ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N);
//===--------------------------------------------------------------------===//
// Vector Splitting Support: LegalizeVectorTypes.cpp
@@ -829,6 +833,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
bool SplitVectorOperand(SDNode *N, unsigned OpNo);
SDValue SplitVecOp_VSELECT(SDNode *N, unsigned OpNo);
SDValue SplitVecOp_VECREDUCE(SDNode *N, unsigned OpNo);
+ SDValue SplitVecOp_VECREDUCE_SEQ(SDNode *N);
SDValue SplitVecOp_UnaryOp(SDNode *N);
SDValue SplitVecOp_TruncateHelper(SDNode *N);
@@ -915,6 +920,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue WidenVecOp_Convert(SDNode *N);
SDValue WidenVecOp_FCOPYSIGN(SDNode *N);
SDValue WidenVecOp_VECREDUCE(SDNode *N);
+ SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
/// Helper function to generate a set of operations to perform
/// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index f109b0781757..0869f618dd35 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -471,10 +471,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
Node->getValueType(0), Scale);
break;
}
- case ISD::VECREDUCE_SEQ_FADD:
- Action = TLI.getOperationAction(Node->getOpcode(),
- Node->getOperand(1).getValueType());
- break;
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
case ISD::VECREDUCE_ADD:
@@ -493,6 +489,10 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
break;
+ case ISD::VECREDUCE_SEQ_FADD:
+ Action = TLI.getOperationAction(Node->getOpcode(),
+ Node->getOperand(1).getValueType());
+ break;
}
LLVM_DEBUG(dbgs() << "\nLegalizing vector op: "; Node->dump(&DAG));
@@ -874,6 +874,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::VECREDUCE_FMIN:
Results.push_back(TLI.expandVecReduce(Node, DAG));
return;
+ case ISD::VECREDUCE_SEQ_FADD:
+ Results.push_back(TLI.expandVecReduceSeq(Node, DAG));
+ return;
case ISD::SREM:
case ISD::UREM:
ExpandREM(Node, Results);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index b46ea1be7a30..e8186a1ee543 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -623,6 +623,9 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VECREDUCE_FMIN:
Res = ScalarizeVecOp_VECREDUCE(N);
break;
+ case ISD::VECREDUCE_SEQ_FADD:
+ Res = ScalarizeVecOp_VECREDUCE_SEQ(N);
+ break;
}
}
@@ -803,6 +806,17 @@ SDValue DAGTypeLegalizer::ScalarizeVecOp_VECREDUCE(SDNode *N) {
return Res;
}
+SDValue DAGTypeLegalizer::ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N) {
+ SDValue AccOp = N->getOperand(0);
+ SDValue VecOp = N->getOperand(1);
+
+ unsigned BaseOpc = ISD::getVecReduceBaseOpcode(N->getOpcode());
+
+ SDValue Op = GetScalarizedVector(VecOp);
+ return DAG.getNode(BaseOpc, SDLoc(N), N->getValueType(0),
+ AccOp, Op, N->getFlags());
+}
+
//===----------------------------------------------------------------------===//
// Result Vector Splitting
//===----------------------------------------------------------------------===//
@@ -2075,6 +2089,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VECREDUCE_FMIN:
Res = SplitVecOp_VECREDUCE(N, OpNo);
break;
+ case ISD::VECREDUCE_SEQ_FADD:
+ Res = SplitVecOp_VECREDUCE_SEQ(N);
+ break;
}
}
@@ -2150,6 +2167,28 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECREDUCE(SDNode *N, unsigned OpNo) {
return DAG.getNode(N->getOpcode(), dl, ResVT, Partial, N->getFlags());
}
+SDValue DAGTypeLegalizer::SplitVecOp_VECREDUCE_SEQ(SDNode *N) {
+ EVT ResVT = N->getValueType(0);
+ SDValue Lo, Hi;
+ SDLoc dl(N);
+
+ SDValue AccOp = N->getOperand(0);
+ SDValue VecOp = N->getOperand(1);
+ SDNodeFlags Flags = N->getFlags();
+
+ EVT VecVT = VecOp.getValueType();
+ assert(VecVT.isVector() && "Can only split reduce vector operand");
+ GetSplitVector(VecOp, Lo, Hi);
+ EVT LoOpVT, HiOpVT;
+ std::tie(LoOpVT, HiOpVT) = DAG.GetSplitDestVTs(VecVT);
+
+ // Reduce low half.
+ SDValue Partial = DAG.getNode(N->getOpcode(), dl, ResVT, AccOp, Lo, Flags);
+
+ // Reduce high half, using low half result as initial value.
+ return DAG.getNode(N->getOpcode(), dl, ResVT, Partial, Hi, Flags);
+}
+
SDValue DAGTypeLegalizer::SplitVecOp_UnaryOp(SDNode *N) {
// The result has a legal vector type, but the input needs splitting.
EVT ResVT = N->getValueType(0);
@@ -4318,6 +4357,9 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VECREDUCE_FMIN:
Res = WidenVecOp_VECREDUCE(N);
break;
+ case ISD::VECREDUCE_SEQ_FADD:
+ Res = WidenVecOp_VECREDUCE_SEQ(N);
+ break;
}
// If Res is null, the sub-method took care of registering the result.
@@ -4757,8 +4799,9 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) {
EVT ElemVT = OrigVT.getVectorElementType();
SDNodeFlags Flags = N->getFlags();
- SDValue NeutralElem = DAG.getNeutralElement(
- ISD::getVecReduceBaseOpcode(N->getOpcode()), dl, ElemVT, Flags);
+ unsigned Opc = N->getOpcode();
+ unsigned BaseOpc = ISD::getVecReduceBaseOpcode(Opc);
+ SDValue NeutralElem = DAG.getNeutralElement(BaseOpc, dl, ElemVT, Flags);
assert(NeutralElem && "Neutral element must exist");
// Pad the vector with the neutral element.
@@ -4768,7 +4811,32 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) {
Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem,
DAG.getVectorIdxConstant(Idx, dl));
- return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), Op, Flags);
+ return DAG.getNode(Opc, dl, N->getValueType(0), Op, Flags);
+}
+
+SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
+ SDLoc dl(N);
+ SDValue AccOp = N->getOperand(0);
+ SDValue VecOp = N->getOperand(1);
+ SDValue Op = GetWidenedVector(VecOp);
+
+ EVT OrigVT = VecOp.getValueType();
+ EVT WideVT = Op.getValueType();
+ EVT ElemVT = OrigVT.getVectorElementType();
+ SDNodeFlags Flags = N->getFlags();
+
+ unsigned Opc = N->getOpcode();
+ unsigned BaseOpc = ISD::getVecReduceBaseOpcode(Opc);
+ SDValue NeutralElem = DAG.getNeutralElement(BaseOpc, dl, ElemVT, Flags);
+
+ // Pad the vector with the neutral element.
+ unsigned OrigElts = OrigVT.getVectorNumElements();
+ unsigned WideElts = WideVT.getVectorNumElements();
+ for (unsigned Idx = OrigElts; Idx < WideElts; Idx++)
+ Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem,
+ DAG.getVectorIdxConstant(Idx, dl));
+
+ return DAG.getNode(Opc, dl, N->getValueType(0), AccOp, Op, Flags);
}
SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 709b0f44f0a5..9b3d904c5f8f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -338,6 +338,7 @@ ISD::NodeType ISD::getVecReduceBaseOpcode(unsigned VecReduceOpcode) {
default:
llvm_unreachable("Expected VECREDUCE opcode");
case ISD::VECREDUCE_FADD:
+ case ISD::VECREDUCE_SEQ_FADD:
return ISD::FADD;
case ISD::VECREDUCE_FMUL:
return ISD::FMUL;
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index c3c521d89c1b..58703c4f3d99 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -8030,6 +8030,28 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
return Res;
}
+SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const {
+ SDLoc dl(Node);
+ SDValue AccOp = Node->getOperand(0);
+ SDValue VecOp = Node->getOperand(1);
+ SDNodeFlags Flags = Node->getFlags();
+
+ EVT VT = VecOp.getValueType();
+ EVT EltVT = VT.getVectorElementType();
+ unsigned NumElts = VT.getVectorNumElements();
+
+ SmallVector<SDValue, 8> Ops;
+ DAG.ExtractVectorElements(VecOp, Ops, 0, NumElts);
+
+ unsigned BaseOpcode = ISD::getVecReduceBaseOpcode(Node->getOpcode());
+
+ SDValue Res = AccOp;
+ for (unsigned i = 0; i < NumElts; i++)
+ Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Flags);
+
+ return Res;
+}
+
bool TargetLowering::expandREM(SDNode *Node, SDValue &Result,
SelectionDAG &DAG) const {
EVT VT = Node->getValueType(0);
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f7dd4b395e46..a61574546367 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -733,6 +733,7 @@ void TargetLoweringBase::initActions() {
setOperationAction(ISD::VECREDUCE_UMIN, VT, Expand);
setOperationAction(ISD::VECREDUCE_FMAX, VT, Expand);
setOperationAction(ISD::VECREDUCE_FMIN, VT, Expand);
+ setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Expand);
}
// Most targets ignore the @llvm.prefetch intrinsic.
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index c3319ff4f905..047dffb85285 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -777,14 +777,6 @@ class AArch64TargetLowering : public TargetLowering {
return !useSVEForFixedLengthVectors();
}
- // FIXME: Move useSVEForFixedLengthVectors*() back to private scope once
- // reduction legalization is complete.
- bool useSVEForFixedLengthVectors() const;
- // Normally SVE is only used for byte size vectors that do not fit within a
- // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
- // used for 64bit and 128bit vectors as well.
- bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
-
private:
/// Keep a pointer to the AArch64Subtarget around so that we can
/// make the right decision when generating code for
different targets.
@@ -1012,6 +1004,12 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldLocalize(const MachineInstr &MI,
const TargetTransformInfo *TTI) const override;
+
+ bool useSVEForFixedLengthVectors() const;
+ // Normally SVE is only used for byte size vectors that do not fit within a
+ // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
+ // used for 64bit and 128bit vectors as well.
+ bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
};
namespace AArch64 {
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index f3ebdc4cc781..0d81ca1fbdd8 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -223,17 +223,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
bool shouldExpandReduction(const IntrinsicInst *II) const {
switch (II->getIntrinsicID()) {
- case Intrinsic::vector_reduce_fadd: {
- Value *VecOp = II->getArgOperand(1);
- EVT VT = TLI->getValueType(getDataLayout(), VecOp->getType());
- if (ST->hasSVE() &&
- TLI->useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
- return false;
-
- return !II->getFastMathFlags().allowReassoc();
- }
case Intrinsic::vector_reduce_fmul:
- // We don't have legalization support for ordered FP reductions.
+ // We don't have legalization support for ordered FMUL reductions.
return !II->getFastMathFlags().allowReassoc();
default:
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 890e905f52a8..5eddcf4ec802 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -195,9 +195,8 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
bool shouldExpandReduction(const IntrinsicInst *II) const {
switch (II->getIntrinsicID()) {
- case Intrinsic::vector_reduce_fadd:
case Intrinsic::vector_reduce_fmul:
- // We don't have legalization support for ordered FP reductions.
+ // We don't have legalization support for ordered FMUL reductions.
return !II->getFastMathFlags().allowReassoc();
default:
// Don't expand anything else, let legalization deal with it.
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
index 9c9da6045049..c89540fee790 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
@@ -63,8 +63,13 @@ define half @fadda_v32f16(half %start, <32 x half>* %a) #0 {
; VBITS_GE_512-NEXT: ret
; Ensure sensible type legalisation.
-; VBITS_EQ_256-COUNT-32: fadd
-; VBITS_EQ_256: ret
+; VBITS_EQ_256: add x8, x0, #32
+; VBITS_EQ_256-NEXT: ptrue [[PG:p[0-9]+]].h, vl16
+; VBITS_EQ_256-DAG: ld1h { [[LO:z[0-9]+]].h }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1h { [[HI:z[0-9]+]].h }, [[PG]]/z, [x8]
+; VBITS_EQ_256-NEXT: fadda h0, [[PG]], h0, [[LO]].h
+; VBITS_EQ_256-NEXT: fadda h0, [[PG]], h0, [[HI]].h
+; VBITS_EQ_256-NEXT: ret
%op = load <32 x half>, <32 x half>* %a
%res = call half @llvm.vector.reduce.fadd.v32f16(half %start, <32 x half> %op)
ret half %res
@@ -131,8 +136,13 @@ define float @fadda_v16f32(float %start, <16 x float>* %a) #0 {
; VBITS_GE_512-NEXT: ret
; Ensure sensible type legalisation.
-; VBITS_EQ_256-COUNT-16: fadd
-; VBITS_EQ_256: ret
+; VBITS_EQ_256: add x8, x0, #32
+; VBITS_EQ_256-NEXT: ptrue [[PG:p[0-9]+]].s, vl8
+; VBITS_EQ_256-DAG: ld1w { [[LO:z[0-9]+]].s }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1w { [[HI:z[0-9]+]].s }, [[PG]]/z, [x8]
+; VBITS_EQ_256-NEXT: fadda s0, [[PG]], s0, [[LO]].s
+; VBITS_EQ_256-NEXT: fadda s0, [[PG]], s0, [[HI]].s
+; VBITS_EQ_256-NEXT: ret
%op = load <16 x float>, <16 x float>* %a
%res = call float @llvm.vector.reduce.fadd.v16f32(float %start, <16 x float> %op)
ret float %res
@@ -199,8 +209,13 @@ define double @fadda_v8f64(double %start, <8 x double>* %a) #0 {
; VBITS_GE_512-NEXT: ret
; Ensure sensible type legalisation.
-; VBITS_EQ_256-COUNT-8: fadd
-; VBITS_EQ_256: ret
+; VBITS_EQ_256: add x8, x0, #32
+; VBITS_EQ_256-NEXT: ptrue [[PG:p[0-9]+]].d, vl4
+; VBITS_EQ_256-DAG: ld1d { [[LO:z[0-9]+]].d }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1d { [[HI:z[0-9]+]].d }, [[PG]]/z, [x8]
+; VBITS_EQ_256-NEXT: fadda d0, [[PG]], d0, [[LO]].d
+; VBITS_EQ_256-NEXT: fadda d0, [[PG]], d0, [[HI]].d
+; VBITS_EQ_256-NEXT: ret
%op = load <8 x double>, <8 x double>* %a
%res = call double @llvm.vector.reduce.fadd.v8f64(double %start, <8 x double> %op)
ret double %res
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll b/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
index 7b957cb11b6e..683ac19ccc5b 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
@@ -108,9 +108,9 @@ define float @test_v3f32(<3 x float> %a, float %s) nounwind {
define float @test_v3f32_neutral(<3 x float> %a) nounwind {
; CHECK-LABEL: test_v3f32_neutral:
; CHECK: // %bb.0:
-; CHECK-NEXT: faddp s1, v0.2s
-; CHECK-NEXT: mov s0, v0.s[2]
-; CHECK-NEXT: fadd s0, s1, s0
+; CHECK-NEXT: mov s1, v0.s[2]
+; CHECK-NEXT: faddp s0, v0.2s
+; CHECK-NEXT: fadd s0, s0, s1
; CHECK-NEXT: ret
%b = call float @llvm.vector.reduce.fadd.f32.v3f32(float -0.0, <3 x float> %a)
ret float %b
@@ -173,34 +173,34 @@ define fp128 @test_v2f128_neutral(<2 x fp128> %a) nounwind {
define float @test_v16f32(<16 x float> %a, float %s) nounwind {
; CHECK-LABEL: test_v16f32:
; CHECK: // %bb.0:
-; CHECK-NEXT: fadd s4, s4, s0
-; CHECK-NEXT: mov s5, v0.s[1]
-; CHECK-NEXT: fadd s4, s4, s5
-; CHECK-NEXT: mov s5, v0.s[2]
-; CHECK-NEXT: mov s0, v0.s[3]
-; CHECK-NEXT: fadd s4, s4, s5
+; CHECK-NEXT: mov s22, v0.s[3]
+; CHECK-NEXT: mov s23, v0.s[2]
+; CHECK-NEXT: mov s24, v0.s[1]
; CHECK-NEXT: fadd s0, s4, s0
-; CHECK-NEXT: mov s5, v1.s[1]
-; CHECK-NEXT: fadd s0, s0, s1
-; CHECK-NEXT: mov s4, v1.s[2]
-; CHECK-NEXT: fadd s0, s0, s5
-; CHECK-NEXT: mov s1, v1.s[3]
-; CHECK-NEXT: fadd s0, s0, s4
+; CHECK-NEXT: fadd s0, s0, s24
+; CHECK-NEXT: fadd s0, s0, s23
+; CHECK-NEXT: fadd s0, s0, s22
+; CHECK-NEXT: mov s21, v1.s[1]
; CHECK-NEXT: fadd s0, s0, s1
-; CHECK-NEXT: mov s5, v2.s[1]
+; CHECK-NEXT: mov s20, v1.s[2]
+; CHECK-NEXT: fadd s0, s0, s21
+; CHECK-NEXT: mov s19, v1.s[3]
+; CHECK-NEXT: fadd s0, s0, s20
+; CHECK-NEXT: fadd s0, s0, s19
+; CHECK-NEXT: mov s18, v2.s[1]
; CHECK-NEXT: fadd s0, s0, s2
-; CHECK-NEXT: mov s4, v2.s[2]
-; CHECK-NEXT: fadd s0, s0, s5
-; CHECK-NEXT: mov s1, v2.s[3]
-; CHECK-NEXT: fadd s0, s0, s4
-; CHECK-NEXT: fadd s0, s0, s1
-; CHECK-NEXT: mov s2, v3.s[1]
+; CHECK-NEXT: mov s17, v2.s[2]
+; CHECK-NEXT: fadd s0, s0, s18
+; CHECK-NEXT: mov s16, v2.s[3]
+; CHECK-NEXT: fadd s0, s0, s17
+; CHECK-NEXT: fadd s0, s0, s16
+; CHECK-NEXT: mov s7, v3.s[1]
; CHECK-NEXT: fadd s0, s0, s3
-; CHECK-NEXT: mov s5, v3.s[2]
-; CHECK-NEXT: fadd s0, s0, s2
+; CHECK-NEXT: mov s6, v3.s[2]
+; CHECK-NEXT: fadd s0, s0, s7
+; CHECK-NEXT: mov s5, v3.s[3]
+; CHECK-NEXT: fadd s0, s0, s6
; CHECK-NEXT: fadd s0, s0, s5
-; CHECK-NEXT: mov s1, v3.s[3]
-; CHECK-NEXT: fadd s0, s0, s1
; CHECK-NEXT: ret
%b = call float @llvm.vector.reduce.fadd.f32.v16f32(float %s, <16 x float> %a)
ret float %b
@@ -209,32 +209,32 @@ define float @test_v16f32(<16 x float> %a, float %s) nounwind {
define float @test_v16f32_neutral(<16 x float> %a) nounwind {
; CHECK-LABEL: test_v16f32_neutral:
; CHECK: // %bb.0:
-; CHECK-NEXT: faddp s4, v0.2s
-; CHECK-NEXT: mov s5, v0.s[2]
-; CHECK-NEXT: mov s0, v0.s[3]
-; CHECK-NEXT: fadd s4, s4, s5
-; CHECK-NEXT: fadd s0, s4, s0
-; CHECK-NEXT: mov s5, v1.s[1]
-; CHECK-NEXT: fadd s0, s0, s1
-; CHECK-NEXT: mov s4, v1.s[2]
-; CHECK-NEXT: fadd s0, s0, s5
-; CHECK-NEXT: mov s1, v1.s[3]
-; CHECK-NEXT: fadd s0, s0, s4
+; CHECK-NEXT: mov s21, v0.s[3]
+; CHECK-NEXT: mov s22, v0.s[2]
+; CHECK-NEXT: faddp s0, v0.2s
+; CHECK-NEXT: fadd s0, s0, s22
+; CHECK-NEXT: fadd s0, s0, s21
+; CHECK-NEXT: mov s20, v1.s[1]
; CHECK-NEXT: fadd s0, s0, s1
-; CHECK-NEXT: mov s5, v2.s[1]
+; CHECK-NEXT: mov s19, v1.s[2]
+; CHECK-NEXT: fadd s0, s0, s20
+; CHECK-NEXT: mov s18, v1.s[3]
+; CHECK-NEXT: fadd s0, s0, s19
+; CHECK-NEXT: fadd s0, s0, s18
+; CHECK-NEXT: mov s17, v2.s[1]
; CHECK-NEXT: fadd s0, s0, s2
-; CHECK-NEXT: mov s4, v2.s[2]
-; CHECK-NEXT: fadd s0, s0, s5
-; CHECK-NEXT: mov s1, v2.s[3]
-; CHECK-NEXT: fadd s0, s0, s4
-; CHECK-NEXT: fadd s0, s0, s1
-; CHECK-NEXT: mov s2, v3.s[1]
+; CHECK-NEXT: mov s16, v2.s[2]
+; CHECK-NEXT: fadd s0, s0, s17
+; CHECK-NEXT: mov s7, v2.s[3]
+; CHECK-NEXT: fadd s0, s0, s16
+; CHECK-NEXT: fadd s0, s0, s7
+; CHECK-NEXT: mov s6, v3.s[1]
; CHECK-NEXT: fadd s0, s0, s3
; CHECK-NEXT: mov s5, v3.s[2]
-; CHECK-NEXT: fadd s0, s0, s2
+; CHECK-NEXT: fadd s0, s0, s6
+; CHECK-NEXT: mov s4, v3.s[3]
; CHECK-NEXT: fadd s0, s0, s5
-; CHECK-NEXT: mov s1, v3.s[3]
-; CHECK-NEXT: fadd s0, s0, s1
+; CHECK-NEXT: fadd s0, s0, s4
; CHECK-NEXT: ret
%b = call float @llvm.vector.reduce.fadd.f32.v16f32(float -0.0, <16 x float> %a)
ret float %b
More information about the llvm-commits
mailing list