[llvm] dda1e74 - [Legalize] Add legalizations for VECREDUCE_SEQ_FADD

Cameron McInally via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 30 14:03:05 PDT 2020


Author: Cameron McInally
Date: 2020-10-30T16:02:55-05:00
New Revision: dda1e74b58bd6eac06e346e6246b906362532f46

URL: https://github.com/llvm/llvm-project/commit/dda1e74b58bd6eac06e346e6246b906362532f46
DIFF: https://github.com/llvm/llvm-project/commit/dda1e74b58bd6eac06e346e6246b906362532f46.diff

LOG: [Legalize] Add legalizations for VECREDUCE_SEQ_FADD

Add Legalization support for VECREDUCE_SEQ_FADD, so that we don't need to depend on ExpandReductionsPass.

Differential Revision: https://reviews.llvm.org/D90247

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/CodeGen/TargetLoweringBase.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
    llvm/lib/Target/ARM/ARMTargetTransformInfo.h
    llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
    llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 064e7608cb6a..8922c9b8db78 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4447,6 +4447,9 @@ class TargetLowering : public TargetLoweringBase {
   /// only the first Count elements of the vector are used.
   SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const;
 
+  /// Expand a VECREDUCE_SEQ_* into an explicit ordered calculation.
+  SDValue expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const;
+
   /// Expand an SREM or UREM using SDIV/UDIV or SDIVREM/UDIVREM, if legal.
   /// Returns true if the expansion was successful.
   bool expandREM(SDNode *Node, SDValue &Result, SelectionDAG &DAG) const;

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index 078d8ad27112..5900350a3fe0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1165,6 +1165,10 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
     Action = TLI.getOperationAction(
         Node->getOpcode(), Node->getOperand(0).getValueType());
     break;
+  case ISD::VECREDUCE_SEQ_FADD:
+    Action = TLI.getOperationAction(
+        Node->getOpcode(), Node->getOperand(1).getValueType());
+    break;
   default:
     if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
       Action = TargetLowering::Legal;

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
index d738ef9df7f1..b8360290b3ca 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
@@ -140,6 +140,9 @@ void DAGTypeLegalizer::SoftenFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::VECREDUCE_FMAX:
       R = SoftenFloatRes_VECREDUCE(N);
       break;
+    case ISD::VECREDUCE_SEQ_FADD:
+      R = SoftenFloatRes_VECREDUCE_SEQ(N);
+      break;
   }
 
   // If R is null, the sub-method took care of registering the result.
@@ -784,6 +787,10 @@ SDValue DAGTypeLegalizer::SoftenFloatRes_VECREDUCE(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGTypeLegalizer::SoftenFloatRes_VECREDUCE_SEQ(SDNode *N) {
+  ReplaceValueWith(SDValue(N, 0), TLI.expandVecReduceSeq(N, DAG));
+  return SDValue();
+}
 
 //===----------------------------------------------------------------------===//
 //  Convert Float Operand to Integer
@@ -2254,6 +2261,9 @@ void DAGTypeLegalizer::PromoteFloatResult(SDNode *N, unsigned ResNo) {
     case ISD::VECREDUCE_FMAX:
       R = PromoteFloatRes_VECREDUCE(N);
       break;
+    case ISD::VECREDUCE_SEQ_FADD:
+      R = PromoteFloatRes_VECREDUCE_SEQ(N);
+      break;
   }
 
   if (R.getNode())
@@ -2494,6 +2504,11 @@ SDValue DAGTypeLegalizer::PromoteFloatRes_VECREDUCE(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGTypeLegalizer::PromoteFloatRes_VECREDUCE_SEQ(SDNode *N) {
+  ReplaceValueWith(SDValue(N, 0), TLI.expandVecReduceSeq(N, DAG));
+  return SDValue();
+}
+
 SDValue DAGTypeLegalizer::BitcastToInt_ATOMIC_SWAP(SDNode *N) {
   EVT VT = N->getValueType(0);
 
@@ -2608,6 +2623,9 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) {
   case ISD::VECREDUCE_FMAX:
     R = SoftPromoteHalfRes_VECREDUCE(N);
     break;
+  case ISD::VECREDUCE_SEQ_FADD:
+    R = SoftPromoteHalfRes_VECREDUCE_SEQ(N);
+    break;
   }
 
   if (R.getNode())
@@ -2806,6 +2824,12 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_VECREDUCE(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGTypeLegalizer::SoftPromoteHalfRes_VECREDUCE_SEQ(SDNode *N) {
+  // Expand and soften.
+  ReplaceValueWith(SDValue(N, 0), TLI.expandVecReduceSeq(N, DAG));
+  return SDValue();
+}
+
 //===----------------------------------------------------------------------===//
 //  Half Operand Soft Promotion
 //===----------------------------------------------------------------------===//

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index cec91c5bbb5b..6ed480f5c17c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -551,6 +551,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SoftenFloatRes_VAARG(SDNode *N);
   SDValue SoftenFloatRes_XINT_TO_FP(SDNode *N);
   SDValue SoftenFloatRes_VECREDUCE(SDNode *N);
+  SDValue SoftenFloatRes_VECREDUCE_SEQ(SDNode *N);
 
   // Convert Float Operand to Integer.
   bool SoftenFloatOperand(SDNode *N, unsigned OpNo);
@@ -670,6 +671,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue BitcastToInt_ATOMIC_SWAP(SDNode *N);
   SDValue PromoteFloatRes_XINT_TO_FP(SDNode *N);
   SDValue PromoteFloatRes_VECREDUCE(SDNode *N);
+  SDValue PromoteFloatRes_VECREDUCE_SEQ(SDNode *N);
 
   bool PromoteFloatOperand(SDNode *N, unsigned OpNo);
   SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
@@ -708,6 +710,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue SoftPromoteHalfRes_XINT_TO_FP(SDNode *N);
   SDValue SoftPromoteHalfRes_UNDEF(SDNode *N);
   SDValue SoftPromoteHalfRes_VECREDUCE(SDNode *N);
+  SDValue SoftPromoteHalfRes_VECREDUCE_SEQ(SDNode *N);
 
   bool SoftPromoteHalfOperand(SDNode *N, unsigned OpNo);
   SDValue SoftPromoteHalfOp_BITCAST(SDNode *N);
@@ -774,6 +777,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue ScalarizeVecOp_FP_ROUND(SDNode *N, unsigned OpNo);
   SDValue ScalarizeVecOp_STRICT_FP_ROUND(SDNode *N, unsigned OpNo);
   SDValue ScalarizeVecOp_VECREDUCE(SDNode *N);
+  SDValue ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N);
 
   //===--------------------------------------------------------------------===//
   // Vector Splitting Support: LegalizeVectorTypes.cpp
@@ -829,6 +833,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   bool SplitVectorOperand(SDNode *N, unsigned OpNo);
   SDValue SplitVecOp_VSELECT(SDNode *N, unsigned OpNo);
   SDValue SplitVecOp_VECREDUCE(SDNode *N, unsigned OpNo);
+  SDValue SplitVecOp_VECREDUCE_SEQ(SDNode *N);
   SDValue SplitVecOp_UnaryOp(SDNode *N);
   SDValue SplitVecOp_TruncateHelper(SDNode *N);
 
@@ -915,6 +920,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecOp_Convert(SDNode *N);
   SDValue WidenVecOp_FCOPYSIGN(SDNode *N);
   SDValue WidenVecOp_VECREDUCE(SDNode *N);
+  SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
 
   /// Helper function to generate a set of operations to perform
   /// a vector operation for a wider type.

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index f109b0781757..0869f618dd35 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -471,10 +471,6 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
                                               Node->getValueType(0), Scale);
     break;
   }
-  case ISD::VECREDUCE_SEQ_FADD:
-    Action = TLI.getOperationAction(Node->getOpcode(),
-                                    Node->getOperand(1).getValueType());
-    break;
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
   case ISD::VECREDUCE_ADD:
@@ -493,6 +489,10 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
     Action = TLI.getOperationAction(Node->getOpcode(),
                                     Node->getOperand(0).getValueType());
     break;
+  case ISD::VECREDUCE_SEQ_FADD:
+    Action = TLI.getOperationAction(Node->getOpcode(),
+                                    Node->getOperand(1).getValueType());
+    break;
   }
 
   LLVM_DEBUG(dbgs() << "\nLegalizing vector op: "; Node->dump(&DAG));
@@ -874,6 +874,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::VECREDUCE_FMIN:
     Results.push_back(TLI.expandVecReduce(Node, DAG));
     return;
+  case ISD::VECREDUCE_SEQ_FADD:
+    Results.push_back(TLI.expandVecReduceSeq(Node, DAG));
+    return;
   case ISD::SREM:
   case ISD::UREM:
     ExpandREM(Node, Results);

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index b46ea1be7a30..e8186a1ee543 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -623,6 +623,9 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
     case ISD::VECREDUCE_FMIN:
       Res = ScalarizeVecOp_VECREDUCE(N);
       break;
+    case ISD::VECREDUCE_SEQ_FADD:
+      Res = ScalarizeVecOp_VECREDUCE_SEQ(N);
+      break;
     }
   }
 
@@ -803,6 +806,17 @@ SDValue DAGTypeLegalizer::ScalarizeVecOp_VECREDUCE(SDNode *N) {
   return Res;
 }
 
+SDValue DAGTypeLegalizer::ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N) {
+  SDValue AccOp = N->getOperand(0);
+  SDValue VecOp = N->getOperand(1);
+
+  unsigned BaseOpc = ISD::getVecReduceBaseOpcode(N->getOpcode());
+
+  SDValue Op = GetScalarizedVector(VecOp);
+  return DAG.getNode(BaseOpc, SDLoc(N), N->getValueType(0),
+                     AccOp, Op, N->getFlags());
+}
+
 //===----------------------------------------------------------------------===//
 //  Result Vector Splitting
 //===----------------------------------------------------------------------===//
@@ -2075,6 +2089,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
     case ISD::VECREDUCE_FMIN:
       Res = SplitVecOp_VECREDUCE(N, OpNo);
       break;
+    case ISD::VECREDUCE_SEQ_FADD:
+      Res = SplitVecOp_VECREDUCE_SEQ(N);
+      break;
     }
   }
 
@@ -2150,6 +2167,28 @@ SDValue DAGTypeLegalizer::SplitVecOp_VECREDUCE(SDNode *N, unsigned OpNo) {
   return DAG.getNode(N->getOpcode(), dl, ResVT, Partial, N->getFlags());
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_VECREDUCE_SEQ(SDNode *N) {
+  EVT ResVT = N->getValueType(0);
+  SDValue Lo, Hi;
+  SDLoc dl(N);
+
+  SDValue AccOp = N->getOperand(0);
+  SDValue VecOp = N->getOperand(1);
+  SDNodeFlags Flags = N->getFlags();
+
+  EVT VecVT = VecOp.getValueType();
+  assert(VecVT.isVector() && "Can only split reduce vector operand");
+  GetSplitVector(VecOp, Lo, Hi);
+  EVT LoOpVT, HiOpVT;
+  std::tie(LoOpVT, HiOpVT) = DAG.GetSplitDestVTs(VecVT);
+
+  // Reduce low half.
+  SDValue Partial = DAG.getNode(N->getOpcode(), dl, ResVT, AccOp, Lo, Flags);
+
+  // Reduce high half, using low half result as initial value.
+  return DAG.getNode(N->getOpcode(), dl, ResVT, Partial, Hi, Flags);
+}
+
 SDValue DAGTypeLegalizer::SplitVecOp_UnaryOp(SDNode *N) {
   // The result has a legal vector type, but the input needs splitting.
   EVT ResVT = N->getValueType(0);
@@ -4318,6 +4357,9 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VECREDUCE_FMIN:
     Res = WidenVecOp_VECREDUCE(N);
     break;
+  case ISD::VECREDUCE_SEQ_FADD:
+    Res = WidenVecOp_VECREDUCE_SEQ(N);
+    break;
   }
 
   // If Res is null, the sub-method took care of registering the result.
@@ -4757,8 +4799,9 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) {
   EVT ElemVT = OrigVT.getVectorElementType();
   SDNodeFlags Flags = N->getFlags();
 
-  SDValue NeutralElem = DAG.getNeutralElement(
-      ISD::getVecReduceBaseOpcode(N->getOpcode()), dl, ElemVT, Flags);
+  unsigned Opc = N->getOpcode();
+  unsigned BaseOpc = ISD::getVecReduceBaseOpcode(Opc);
+  SDValue NeutralElem = DAG.getNeutralElement(BaseOpc, dl, ElemVT, Flags);
   assert(NeutralElem && "Neutral element must exist");
 
   // Pad the vector with the neutral element.
@@ -4768,7 +4811,32 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) {
     Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem,
                      DAG.getVectorIdxConstant(Idx, dl));
 
-  return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), Op, Flags);
+  return DAG.getNode(Opc, dl, N->getValueType(0), Op, Flags);
+}
+
+SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
+  SDLoc dl(N);
+  SDValue AccOp = N->getOperand(0);
+  SDValue VecOp = N->getOperand(1);
+  SDValue Op = GetWidenedVector(VecOp);
+
+  EVT OrigVT = VecOp.getValueType();
+  EVT WideVT = Op.getValueType();
+  EVT ElemVT = OrigVT.getVectorElementType();
+  SDNodeFlags Flags = N->getFlags();
+
+  unsigned Opc = N->getOpcode();
+  unsigned BaseOpc = ISD::getVecReduceBaseOpcode(Opc);
+  SDValue NeutralElem = DAG.getNeutralElement(BaseOpc, dl, ElemVT, Flags);
+
+  // Pad the vector with the neutral element.
+  unsigned OrigElts = OrigVT.getVectorNumElements();
+  unsigned WideElts = WideVT.getVectorNumElements();
+  for (unsigned Idx = OrigElts; Idx < WideElts; Idx++)
+    Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem,
+                     DAG.getVectorIdxConstant(Idx, dl));
+
+  return DAG.getNode(Opc, dl, N->getValueType(0), AccOp, Op, Flags);
 }
 
 SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 709b0f44f0a5..9b3d904c5f8f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -338,6 +338,7 @@ ISD::NodeType ISD::getVecReduceBaseOpcode(unsigned VecReduceOpcode) {
   default:
     llvm_unreachable("Expected VECREDUCE opcode");
   case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
     return ISD::FADD;
   case ISD::VECREDUCE_FMUL:
     return ISD::FMUL;

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index c3c521d89c1b..58703c4f3d99 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -8030,6 +8030,28 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
   return Res;
 }
 
+SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const {
+  SDLoc dl(Node);
+  SDValue AccOp = Node->getOperand(0);
+  SDValue VecOp = Node->getOperand(1);
+  SDNodeFlags Flags = Node->getFlags();
+
+  EVT VT = VecOp.getValueType();
+  EVT EltVT = VT.getVectorElementType();
+  unsigned NumElts = VT.getVectorNumElements();
+
+  SmallVector<SDValue, 8> Ops;
+  DAG.ExtractVectorElements(VecOp, Ops, 0, NumElts);
+
+  unsigned BaseOpcode = ISD::getVecReduceBaseOpcode(Node->getOpcode());
+
+  SDValue Res = AccOp;
+  for (unsigned i = 0; i < NumElts; i++)
+    Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Flags);
+
+  return Res;
+}
+
 bool TargetLowering::expandREM(SDNode *Node, SDValue &Result,
                                SelectionDAG &DAG) const {
   EVT VT = Node->getValueType(0);

diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index f7dd4b395e46..a61574546367 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -733,6 +733,7 @@ void TargetLoweringBase::initActions() {
     setOperationAction(ISD::VECREDUCE_UMIN, VT, Expand);
     setOperationAction(ISD::VECREDUCE_FMAX, VT, Expand);
     setOperationAction(ISD::VECREDUCE_FMIN, VT, Expand);
+    setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Expand);
   }
 
   // Most targets ignore the @llvm.prefetch intrinsic.

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index c3319ff4f905..047dffb85285 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -777,14 +777,6 @@ class AArch64TargetLowering : public TargetLowering {
     return !useSVEForFixedLengthVectors();
   }
 
-  // FIXME: Move useSVEForFixedLengthVectors*() back to private scope once
-  // reduction legalization is complete.      
-  bool useSVEForFixedLengthVectors() const;
-  // Normally SVE is only used for byte size vectors that do not fit within a
-  // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
-  // used for 64bit and 128bit vectors as well.
-  bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
-
 private:
   /// Keep a pointer to the AArch64Subtarget around so that we can
  /// make the right decision when generating code for different targets.
@@ -1012,6 +1004,12 @@ class AArch64TargetLowering : public TargetLowering {
 
   bool shouldLocalize(const MachineInstr &MI,
                       const TargetTransformInfo *TTI) const override;
+
+  bool useSVEForFixedLengthVectors() const;
+  // Normally SVE is only used for byte size vectors that do not fit within a
+  // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
+  // used for 64bit and 128bit vectors as well.
+  bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
 };
 
 namespace AArch64 {

diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index f3ebdc4cc781..0d81ca1fbdd8 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -223,17 +223,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
 
   bool shouldExpandReduction(const IntrinsicInst *II) const {
     switch (II->getIntrinsicID()) {
-    case Intrinsic::vector_reduce_fadd: {
-      Value *VecOp = II->getArgOperand(1);
-      EVT VT = TLI->getValueType(getDataLayout(), VecOp->getType());
-      if (ST->hasSVE() &&
-          TLI->useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
-        return false;
-
-      return !II->getFastMathFlags().allowReassoc();
-    }
     case Intrinsic::vector_reduce_fmul:
-      // We don't have legalization support for ordered FP reductions.
+      // We don't have legalization support for ordered FMUL reductions.
       return !II->getFastMathFlags().allowReassoc();
 
     default:

diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
index 890e905f52a8..5eddcf4ec802 100644
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -195,9 +195,8 @@ class ARMTTIImpl : public BasicTTIImplBase<ARMTTIImpl> {
 
   bool shouldExpandReduction(const IntrinsicInst *II) const {
     switch (II->getIntrinsicID()) {
-    case Intrinsic::vector_reduce_fadd:
     case Intrinsic::vector_reduce_fmul:
-      // We don't have legalization support for ordered FP reductions.
+      // We don't have legalization support for ordered FMUL reductions.
       return !II->getFastMathFlags().allowReassoc();
     default:
       // Don't expand anything else, let legalization deal with it.

diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
index 9c9da6045049..c89540fee790 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
@@ -63,8 +63,13 @@ define half @fadda_v32f16(half %start, <32 x half>* %a) #0 {
 ; VBITS_GE_512-NEXT: ret
 
 ; Ensure sensible type legalisation.
-; VBITS_EQ_256-COUNT-32: fadd
-; VBITS_EQ_256: ret
+; VBITS_EQ_256: add x8, x0, #32
+; VBITS_EQ_256-NEXT: ptrue [[PG:p[0-9]+]].h, vl16
+; VBITS_EQ_256-DAG: ld1h { [[LO:z[0-9]+]].h }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1h { [[HI:z[0-9]+]].h }, [[PG]]/z, [x8]
+; VBITS_EQ_256-NEXT: fadda h0, [[PG]], h0, [[LO]].h
+; VBITS_EQ_256-NEXT: fadda h0, [[PG]], h0, [[HI]].h
+; VBITS_EQ_256-NEXT: ret
   %op = load <32 x half>, <32 x half>* %a
   %res = call half @llvm.vector.reduce.fadd.v32f16(half %start, <32 x half> %op)
   ret half %res
@@ -131,8 +136,13 @@ define float @fadda_v16f32(float %start, <16 x float>* %a) #0 {
 ; VBITS_GE_512-NEXT: ret
 
 ; Ensure sensible type legalisation.
-; VBITS_EQ_256-COUNT-16: fadd
-; VBITS_EQ_256: ret
+; VBITS_EQ_256: add x8, x0, #32
+; VBITS_EQ_256-NEXT: ptrue [[PG:p[0-9]+]].s, vl8
+; VBITS_EQ_256-DAG: ld1w { [[LO:z[0-9]+]].s }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1w { [[HI:z[0-9]+]].s }, [[PG]]/z, [x8]
+; VBITS_EQ_256-NEXT: fadda s0, [[PG]], s0, [[LO]].s
+; VBITS_EQ_256-NEXT: fadda s0, [[PG]], s0, [[HI]].s
+; VBITS_EQ_256-NEXT: ret
   %op = load <16 x float>, <16 x float>* %a
   %res = call float @llvm.vector.reduce.fadd.v16f32(float %start, <16 x float> %op)
   ret float %res
@@ -199,8 +209,13 @@ define double @fadda_v8f64(double %start, <8 x double>* %a) #0 {
 ; VBITS_GE_512-NEXT: ret
 
 ; Ensure sensible type legalisation.
-; VBITS_EQ_256-COUNT-8: fadd
-; VBITS_EQ_256: ret
+; VBITS_EQ_256: add x8, x0, #32
+; VBITS_EQ_256-NEXT: ptrue [[PG:p[0-9]+]].d, vl4
+; VBITS_EQ_256-DAG: ld1d { [[LO:z[0-9]+]].d }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1d { [[HI:z[0-9]+]].d }, [[PG]]/z, [x8]
+; VBITS_EQ_256-NEXT: fadda d0, [[PG]], d0, [[LO]].d
+; VBITS_EQ_256-NEXT: fadda d0, [[PG]], d0, [[HI]].d
+; VBITS_EQ_256-NEXT: ret
   %op = load <8 x double>, <8 x double>* %a
   %res = call double @llvm.vector.reduce.fadd.v8f64(double %start, <8 x double> %op)
   ret double %res

diff --git a/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll b/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
index 7b957cb11b6e..683ac19ccc5b 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
@@ -108,9 +108,9 @@ define float @test_v3f32(<3 x float> %a, float %s) nounwind {
 define float @test_v3f32_neutral(<3 x float> %a) nounwind {
 ; CHECK-LABEL: test_v3f32_neutral:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    faddp s1, v0.2s
-; CHECK-NEXT:    mov s0, v0.s[2]
-; CHECK-NEXT:    fadd s0, s1, s0
+; CHECK-NEXT:    mov s1, v0.s[2]
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    fadd s0, s0, s1
 ; CHECK-NEXT:    ret
   %b = call float @llvm.vector.reduce.fadd.f32.v3f32(float -0.0, <3 x float> %a)
   ret float %b
@@ -173,34 +173,34 @@ define fp128 @test_v2f128_neutral(<2 x fp128> %a) nounwind {
 define float @test_v16f32(<16 x float> %a, float %s) nounwind {
 ; CHECK-LABEL: test_v16f32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fadd s4, s4, s0
-; CHECK-NEXT:    mov s5, v0.s[1]
-; CHECK-NEXT:    fadd s4, s4, s5
-; CHECK-NEXT:    mov s5, v0.s[2]
-; CHECK-NEXT:    mov s0, v0.s[3]
-; CHECK-NEXT:    fadd s4, s4, s5
+; CHECK-NEXT:    mov s22, v0.s[3]
+; CHECK-NEXT:    mov s23, v0.s[2]
+; CHECK-NEXT:    mov s24, v0.s[1]
 ; CHECK-NEXT:    fadd s0, s4, s0
-; CHECK-NEXT:    mov s5, v1.s[1]
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s4, v1.s[2]
-; CHECK-NEXT:    fadd s0, s0, s5
-; CHECK-NEXT:    mov s1, v1.s[3]
-; CHECK-NEXT:    fadd s0, s0, s4
+; CHECK-NEXT:    fadd s0, s0, s24
+; CHECK-NEXT:    fadd s0, s0, s23
+; CHECK-NEXT:    fadd s0, s0, s22
+; CHECK-NEXT:    mov s21, v1.s[1]
 ; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s5, v2.s[1]
+; CHECK-NEXT:    mov s20, v1.s[2]
+; CHECK-NEXT:    fadd s0, s0, s21
+; CHECK-NEXT:    mov s19, v1.s[3]
+; CHECK-NEXT:    fadd s0, s0, s20
+; CHECK-NEXT:    fadd s0, s0, s19
+; CHECK-NEXT:    mov s18, v2.s[1]
 ; CHECK-NEXT:    fadd s0, s0, s2
-; CHECK-NEXT:    mov s4, v2.s[2]
-; CHECK-NEXT:    fadd s0, s0, s5
-; CHECK-NEXT:    mov s1, v2.s[3]
-; CHECK-NEXT:    fadd s0, s0, s4
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s2, v3.s[1]
+; CHECK-NEXT:    mov s17, v2.s[2]
+; CHECK-NEXT:    fadd s0, s0, s18
+; CHECK-NEXT:    mov s16, v2.s[3]
+; CHECK-NEXT:    fadd s0, s0, s17
+; CHECK-NEXT:    fadd s0, s0, s16
+; CHECK-NEXT:    mov s7, v3.s[1]
 ; CHECK-NEXT:    fadd s0, s0, s3
-; CHECK-NEXT:    mov s5, v3.s[2]
-; CHECK-NEXT:    fadd s0, s0, s2
+; CHECK-NEXT:    mov s6, v3.s[2]
+; CHECK-NEXT:    fadd s0, s0, s7
+; CHECK-NEXT:    mov s5, v3.s[3]
+; CHECK-NEXT:    fadd s0, s0, s6
 ; CHECK-NEXT:    fadd s0, s0, s5
-; CHECK-NEXT:    mov s1, v3.s[3]
-; CHECK-NEXT:    fadd s0, s0, s1
 ; CHECK-NEXT:    ret
   %b = call float @llvm.vector.reduce.fadd.f32.v16f32(float %s, <16 x float> %a)
   ret float %b
@@ -209,32 +209,32 @@ define float @test_v16f32(<16 x float> %a, float %s) nounwind {
 define float @test_v16f32_neutral(<16 x float> %a) nounwind {
 ; CHECK-LABEL: test_v16f32_neutral:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    faddp s4, v0.2s
-; CHECK-NEXT:    mov s5, v0.s[2]
-; CHECK-NEXT:    mov s0, v0.s[3]
-; CHECK-NEXT:    fadd s4, s4, s5
-; CHECK-NEXT:    fadd s0, s4, s0
-; CHECK-NEXT:    mov s5, v1.s[1]
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s4, v1.s[2]
-; CHECK-NEXT:    fadd s0, s0, s5
-; CHECK-NEXT:    mov s1, v1.s[3]
-; CHECK-NEXT:    fadd s0, s0, s4
+; CHECK-NEXT:    mov s21, v0.s[3]
+; CHECK-NEXT:    mov s22, v0.s[2]
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    fadd s0, s0, s22
+; CHECK-NEXT:    fadd s0, s0, s21
+; CHECK-NEXT:    mov s20, v1.s[1]
 ; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s5, v2.s[1]
+; CHECK-NEXT:    mov s19, v1.s[2]
+; CHECK-NEXT:    fadd s0, s0, s20
+; CHECK-NEXT:    mov s18, v1.s[3]
+; CHECK-NEXT:    fadd s0, s0, s19
+; CHECK-NEXT:    fadd s0, s0, s18
+; CHECK-NEXT:    mov s17, v2.s[1]
 ; CHECK-NEXT:    fadd s0, s0, s2
-; CHECK-NEXT:    mov s4, v2.s[2]
-; CHECK-NEXT:    fadd s0, s0, s5
-; CHECK-NEXT:    mov s1, v2.s[3]
-; CHECK-NEXT:    fadd s0, s0, s4
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s2, v3.s[1]
+; CHECK-NEXT:    mov s16, v2.s[2]
+; CHECK-NEXT:    fadd s0, s0, s17
+; CHECK-NEXT:    mov s7, v2.s[3]
+; CHECK-NEXT:    fadd s0, s0, s16
+; CHECK-NEXT:    fadd s0, s0, s7
+; CHECK-NEXT:    mov s6, v3.s[1]
 ; CHECK-NEXT:    fadd s0, s0, s3
 ; CHECK-NEXT:    mov s5, v3.s[2]
-; CHECK-NEXT:    fadd s0, s0, s2
+; CHECK-NEXT:    fadd s0, s0, s6
+; CHECK-NEXT:    mov s4, v3.s[3]
 ; CHECK-NEXT:    fadd s0, s0, s5
-; CHECK-NEXT:    mov s1, v3.s[3]
-; CHECK-NEXT:    fadd s0, s0, s1
+; CHECK-NEXT:    fadd s0, s0, s4
 ; CHECK-NEXT:    ret
   %b = call float @llvm.vector.reduce.fadd.f32.v16f32(float -0.0, <16 x float> %a)
   ret float %b


        


More information about the llvm-commits mailing list