[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 11 06:44:25 PST 2024


https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/117185

>From 54512e50e0cd8537347eeda4b39f1bdc3f9848f3 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 21 Nov 2024 16:19:19 +0000
Subject: [PATCH 1/7] [AArch64][SVE] Add partial reduction SDNodes

Add the opcode 'ISD::PARTIAL_REDUCE_ADD' and use it when building
SDNodes. When the input and output types allow lowering to wide add
or dot product instruction(s), convert the corresponding intrinsic
to this SDNode instead of keeping it as a target intrinsic. This will
make legalisation, which will be added in a future patch, easier to
implement.
---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  5 ++
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  5 ++
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 16 ++++++
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp |  6 ++
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 12 +++-
 .../SelectionDAG/SelectionDAGBuilder.h        |  1 +
 .../SelectionDAG/SelectionDAGDumper.cpp       |  3 +
 .../Target/AArch64/AArch64ISelLowering.cpp    | 56 +++++++++----------
 8 files changed, 74 insertions(+), 30 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 0b6d155b6d161e..7809a1b26dd7cd 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,6 +1451,11 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
+  // The `llvm.experimental.vector.partial.reduce.add` intrinsic
+  // Operands: Accumulator, Input
+  // Outputs: Output
+  PARTIAL_REDUCE_ADD,
+
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
   // Outputs: output chain, glue
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index e97e01839f73b4..c94ca74617e90f 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1599,6 +1599,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  /// Get a partial reduction SD node for the DAG. This is done when the input
+  /// and output types can be legalised for wide add(s) or dot product(s)
+  SDValue getPartialReduceAddSDNode(SDLoc DL, SDValue Chain, SDValue Acc,
+                                    SDValue Input);
+
   /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
   /// its operands and ReducedTY is the intrinsic's return type.
   SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 61f3c6329efce8..fe8e68e9ef5204 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3010,6 +3010,22 @@ class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
   }
 };
 
+class PartialReduceAddSDNode : public SDNode {
+public:
+  friend class SelectionDAG;
+
+  PartialReduceAddSDNode(const DebugLoc &dl, SDVTList VTs)
+      : SDNode(ISD::PARTIAL_REDUCE_ADD, 0, dl, VTs) {}
+
+  const SDValue &getChain() const { return getOperand(0); }
+  const SDValue &getAcc() const { return getOperand(1); }
+  const SDValue &getInput() const { return getOperand(2); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::PARTIAL_REDUCE_ADD;
+  }
+};
+
 class FPStateAccessSDNode : public MemSDNode {
 public:
   friend class SelectionDAG;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 34214550f3a12b..5152f0e4ddd6da 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2445,6 +2445,12 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::getPartialReduceAddSDNode(SDLoc DL, SDValue Chain,
+                                                SDValue Acc, SDValue Input) {
+  return getNode(ISD::PARTIAL_REDUCE_ADD, DL, Acc.getValueType(), Chain, Acc,
+                 Input);
+}
+
 SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
                                           SDValue Op2) {
   EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index b72c5eff22f183..a60b38768cd0c3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6427,6 +6427,16 @@ void SelectionDAGBuilder::visitVectorHistogram(const CallInst &I,
   DAG.setRoot(Histogram);
 }
 
+void SelectionDAGBuilder::visitPartialReduceAdd(const CallInst &I,
+                                                unsigned IntrinsicID) {
+  SDLoc dl = getCurSDLoc();
+  SDValue Acc = getValue(I.getOperand(0));
+  SDValue Input = getValue(I.getOperand(1));
+  SDValue Chain = getRoot();
+
+  setValue(&I, DAG.getPartialReduceAddSDNode(dl, Chain, Acc, Input));
+}
+
 void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
                                                        unsigned Intrinsic) {
   assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
@@ -8142,7 +8152,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
   case Intrinsic::experimental_vector_partial_reduce_add: {
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
-      visitTargetIntrinsic(I, Intrinsic);
+      visitPartialReduceAdd(I, Intrinsic);
       return;
     }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index 3a8dc25e98700e..a9e0c8f1ea10c1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -629,6 +629,7 @@ class SelectionDAGBuilder {
   void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI);
   void visitConvergenceControl(const CallInst &I, unsigned Intrinsic);
   void visitVectorHistogram(const CallInst &I, unsigned IntrinsicID);
+  void visitPartialReduceAdd(const CallInst &, unsigned IntrinsicID);
   void visitVectorExtractLastActive(const CallInst &I, unsigned Intrinsic);
   void visitVPLoad(const VPIntrinsic &VPIntrin, EVT VT,
                    const SmallVectorImpl<SDValue> &OpValues);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 580ff19065557b..8ce03b14bda46c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -567,6 +567,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return "histogram";
 
+  case ISD::PARTIAL_REDUCE_ADD:
+    return "partial_reduce_add";
+
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
   case ISD::SDID:                                                              \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1354ccf376609..5e61c6f36cb47b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1124,6 +1124,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(
       {ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
+  setTargetDAGCombine(ISD::PARTIAL_REDUCE_ADD);
+
   setTargetDAGCombine(ISD::FP_EXTEND);
 
   setTargetDAGCombine(ISD::GlobalAddress);
@@ -21721,26 +21723,21 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N,
+SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
                                       const AArch64Subtarget *Subtarget,
                                       SelectionDAG &DAG) {
 
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
-  bool Scalable = N->getValueType(0).isScalableVector();
+  bool Scalable = PR->getValueType(0).isScalableVector();
   if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
   if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
     return SDValue();
 
-  SDLoc DL(N);
+  SDLoc DL(PR);
 
   // The narrower of the two operands. Used as the accumulator
-  auto NarrowOp = N->getOperand(1);
-  auto MulOp = N->getOperand(2);
+  auto NarrowOp = PR->getAcc();
+  auto MulOp = PR->getInput();
   if (MulOp->getOpcode() != ISD::MUL)
     return SDValue();
 
@@ -21758,7 +21755,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  EVT ReducedType = N->getValueType(0);
+  EVT ReducedType = PR->getValueType(0);
   EVT MulSrcType = A.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
@@ -21777,7 +21774,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
-    bool Scalable = N->getValueType(0).isScalableVT();
+    bool Scalable = PR->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
     if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
       return SDValue();
@@ -21808,22 +21805,17 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
                                           const AArch64Subtarget *Subtarget,
                                           SelectionDAG &DAG) {
 
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
   if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
     return SDValue();
 
-  SDLoc DL(N);
+  SDLoc DL(PR);
 
-  auto Acc = N->getOperand(1);
-  auto ExtInput = N->getOperand(2);
+  auto Acc = PR->getAcc();
+  auto ExtInput = PR->getInput();
 
   EVT AccVT = Acc.getValueType();
   EVT AccElemVT = AccVT.getVectorElementType();
@@ -21850,6 +21842,18 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
   return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
 }
 
+static SDValue
+performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+                               const AArch64Subtarget *Subtarget) {
+  auto *PR = cast<PartialReduceAddSDNode>(N);
+  if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+    return Dot;
+  if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+    return WideAdd;
+  return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
+                                 PR->getInput());
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21858,14 +21862,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
-  case Intrinsic::experimental_vector_partial_reduce_add: {
-    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
-      return Dot;
-    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
-      return WideAdd;
-    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
-                                   N->getOperand(1), N->getOperand(2));
-  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
@@ -26156,6 +26152,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::MSCATTER:
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return performMaskedGatherScatterCombine(N, DCI, DAG);
+  case ISD::PARTIAL_REDUCE_ADD:
+    return performPartialReduceAddCombine(N, DAG, Subtarget);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
   case AArch64ISD::BRCOND:

>From 14dd3e3d257cd3674f0046e28b2605a1818efbb0 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 27 Nov 2024 09:48:06 +0000
Subject: [PATCH 2/7] Rework previous patch: remove PartialReduceAddSDNode
 and change how the intrinsic is transformed into the SD node.

---
 llvm/include/llvm/CodeGen/SelectionDAG.h      | 13 +++-----
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 16 ----------
 .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 10 ++----
 .../SelectionDAG/SelectionDAGBuilder.cpp      | 21 ++++---------
 .../SelectionDAG/SelectionDAGBuilder.h        |  1 -
 .../Target/AArch64/AArch64ISelLowering.cpp    | 31 +++++++++----------
 6 files changed, 27 insertions(+), 65 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index c94ca74617e90f..b361e5d33b8915 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1599,15 +1599,10 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
-  /// Get a partial reduction SD node for the DAG. This is done when the input
-  /// and output types can be legalised for wide add(s) or dot product(s)
-  SDValue getPartialReduceAddSDNode(SDLoc DL, SDValue Chain, SDValue Acc,
-                                    SDValue Input);
-
-  /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
-  /// its operands and ReducedTY is the intrinsic's return type.
-  SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
-                              SDValue Op2);
+  /// Expands partial reduce node which can't be lowered to wide add or dot
+  /// product instruction(s)
+  SDValue expandPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+                                 SDValue Op2);
 
   /// Expands a node with multiple results to an FP or vector libcall. The
   /// libcall is expected to take all the operands of the \p Node followed by
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index fe8e68e9ef5204..61f3c6329efce8 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3010,22 +3010,6 @@ class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
   }
 };
 
-class PartialReduceAddSDNode : public SDNode {
-public:
-  friend class SelectionDAG;
-
-  PartialReduceAddSDNode(const DebugLoc &dl, SDVTList VTs)
-      : SDNode(ISD::PARTIAL_REDUCE_ADD, 0, dl, VTs) {}
-
-  const SDValue &getChain() const { return getOperand(0); }
-  const SDValue &getAcc() const { return getOperand(1); }
-  const SDValue &getInput() const { return getOperand(2); }
-
-  static bool classof(const SDNode *N) {
-    return N->getOpcode() == ISD::PARTIAL_REDUCE_ADD;
-  }
-};
-
 class FPStateAccessSDNode : public MemSDNode {
 public:
   friend class SelectionDAG;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5152f0e4ddd6da..0f01964de76630 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2445,14 +2445,8 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
-SDValue SelectionDAG::getPartialReduceAddSDNode(SDLoc DL, SDValue Chain,
-                                                SDValue Acc, SDValue Input) {
-  return getNode(ISD::PARTIAL_REDUCE_ADD, DL, Acc.getValueType(), Chain, Acc,
-                 Input);
-}
-
-SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
-                                          SDValue Op2) {
+SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, EVT ReducedTy,
+                                             SDValue Op1, SDValue Op2) {
   EVT FullTy = Op2.getValueType();
 
   unsigned Stride = ReducedTy.getVectorMinNumElements();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index a60b38768cd0c3..71ee37df8ac895 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6427,16 +6427,6 @@ void SelectionDAGBuilder::visitVectorHistogram(const CallInst &I,
   DAG.setRoot(Histogram);
 }
 
-void SelectionDAGBuilder::visitPartialReduceAdd(const CallInst &I,
-                                                unsigned IntrinsicID) {
-  SDLoc dl = getCurSDLoc();
-  SDValue Acc = getValue(I.getOperand(0));
-  SDValue Input = getValue(I.getOperand(1));
-  SDValue Chain = getRoot();
-
-  setValue(&I, DAG.getPartialReduceAddSDNode(dl, Chain, Acc, Input));
-}
-
 void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
                                                        unsigned Intrinsic) {
   assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
@@ -8150,15 +8140,16 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     return;
   }
   case Intrinsic::experimental_vector_partial_reduce_add: {
+    SDLoc dl = getCurSDLoc();
+    SDValue Acc = getValue(I.getOperand(0));
+    EVT AccVT = Acc.getValueType();
+    SDValue Input = getValue(I.getOperand(1));
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
-      visitPartialReduceAdd(I, Intrinsic);
+      setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_ADD, dl, AccVT, Acc, Input));
       return;
     }
-
-    setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
-                                         getValue(I.getOperand(0)),
-                                         getValue(I.getOperand(1))));
+    setValue(&I, DAG.expandPartialReduceAdd(dl, AccVT, Acc, Input));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index a9e0c8f1ea10c1..3a8dc25e98700e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -629,7 +629,6 @@ class SelectionDAGBuilder {
   void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI);
   void visitConvergenceControl(const CallInst &I, unsigned Intrinsic);
   void visitVectorHistogram(const CallInst &I, unsigned IntrinsicID);
-  void visitPartialReduceAdd(const CallInst &, unsigned IntrinsicID);
   void visitVectorExtractLastActive(const CallInst &I, unsigned Intrinsic);
   void visitVPLoad(const VPIntrinsic &VPIntrin, EVT VT,
                    const SmallVectorImpl<SDValue> &OpValues);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5e61c6f36cb47b..aa14ea9d25ca99 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21723,21 +21723,21 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
+SDValue tryLowerPartialReductionToDot(SDNode *N,
                                       const AArch64Subtarget *Subtarget,
                                       SelectionDAG &DAG) {
 
-  bool Scalable = PR->getValueType(0).isScalableVector();
+  bool Scalable = N->getValueType(0).isScalableVector();
   if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
   if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
     return SDValue();
 
-  SDLoc DL(PR);
+  SDLoc DL(N);
 
   // The narrower of the two operands. Used as the accumulator
-  auto NarrowOp = PR->getAcc();
-  auto MulOp = PR->getInput();
+  auto NarrowOp = N->getOperand(0);
+  auto MulOp = N->getOperand(1);
   if (MulOp->getOpcode() != ISD::MUL)
     return SDValue();
 
@@ -21755,7 +21755,7 @@ SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  EVT ReducedType = PR->getValueType(0);
+  EVT ReducedType = N->getValueType(0);
   EVT MulSrcType = A.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
@@ -21774,7 +21774,7 @@ SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
-    bool Scalable = PR->getValueType(0).isScalableVT();
+    bool Scalable = N->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
     if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
       return SDValue();
@@ -21805,17 +21805,17 @@ SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
-SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
                                           const AArch64Subtarget *Subtarget,
                                           SelectionDAG &DAG) {
 
   if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
     return SDValue();
 
-  SDLoc DL(PR);
+  SDLoc DL(N);
 
-  auto Acc = PR->getAcc();
-  auto ExtInput = PR->getInput();
+  auto Acc = N->getOperand(0);
+  auto ExtInput = N->getOperand(1);
 
   EVT AccVT = Acc.getValueType();
   EVT AccElemVT = AccVT.getVectorElementType();
@@ -21845,13 +21845,12 @@ SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
 static SDValue
 performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
                                const AArch64Subtarget *Subtarget) {
-  auto *PR = cast<PartialReduceAddSDNode>(N);
-  if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+  if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
     return Dot;
-  if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+  if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
     return WideAdd;
-  return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
-                                 PR->getInput());
+  return DAG.expandPartialReduceAdd(SDLoc(N), N->getValueType(0),
+                                    N->getOperand(0), N->getOperand(1));
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,

>From 34e81dd7fe9801cd301613bf191b5fb88c5ec936 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 28 Nov 2024 12:11:36 +0000
Subject: [PATCH 3/7] Remove unnecessary function parameter and update comments

---
 llvm/include/llvm/CodeGen/ISDOpcodes.h                | 3 ++-
 llvm/include/llvm/CodeGen/SelectionDAG.h              | 9 +++++----
 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp        | 5 +++--
 llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 2 +-
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp       | 3 +--
 5 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 7809a1b26dd7cd..63a9e322400753 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,7 +1451,8 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
-  // The `llvm.experimental.vector.partial.reduce.add` intrinsic
+  // This corresponds to the `llvm.experimental.vector.partial.reduce.add`
+  // intrinsic
   // Operands: Accumulator, Input
   // Outputs: Output
   PARTIAL_REDUCE_ADD,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index b361e5d33b8915..7d7637a2d4df6e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1599,10 +1599,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
-  /// Expands partial reduce node which can't be lowered to wide add or dot
-  /// product instruction(s)
-  SDValue expandPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
-                                 SDValue Op2);
+  /// Expands PARTIAL_REDUCE_ADD nodes which can't be lowered.
+  /// @param Op1 Accumulator for where the result is stored for the partial
+  /// reduction operation
+  /// @param Op2 Input for the partial reduction operation
+  SDValue expandPartialReduceAdd(SDLoc DL, SDValue Op1, SDValue Op2);
 
   /// Expands a node with multiple results to an FP or vector libcall. The
   /// libcall is expected to take all the operands of the \p Node followed by
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0f01964de76630..611835c8cc53bb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2445,8 +2445,9 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
-SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, EVT ReducedTy,
-                                             SDValue Op1, SDValue Op2) {
+SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Op1,
+                                             SDValue Op2) {
+  EVT ReducedTy = Op1.getValueType();
   EVT FullTy = Op2.getValueType();
 
   unsigned Stride = ReducedTy.getVectorMinNumElements();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 71ee37df8ac895..948fea268d18fb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8149,7 +8149,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
       setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_ADD, dl, AccVT, Acc, Input));
       return;
     }
-    setValue(&I, DAG.expandPartialReduceAdd(dl, AccVT, Acc, Input));
+    setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
     return;
   }
   case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index aa14ea9d25ca99..badf61f662db1e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21849,8 +21849,7 @@ performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
     return Dot;
   if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
     return WideAdd;
-  return DAG.expandPartialReduceAdd(SDLoc(N), N->getValueType(0),
-                                    N->getOperand(0), N->getOperand(1));
+  return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0), N->getOperand(1));
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,

>From 32e5549d59ab23d81af7674b01f4b8cda36ee630 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 28 Nov 2024 13:17:15 +0000
Subject: [PATCH 4/7] Code formatting changes

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index badf61f662db1e..152d3631c85970 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21849,7 +21849,8 @@ performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
     return Dot;
   if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
     return WideAdd;
-  return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0), N->getOperand(1));
+  return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0),
+                                    N->getOperand(1));
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,

>From 9014cf30c83458ef40fe66462c884a32e0c60c25 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 2 Dec 2024 09:41:36 +0000
Subject: [PATCH 5/7] Split the single partial reduction ISD node into signed
 and unsigned variants (PARTIAL_REDUCE_SADD / PARTIAL_REDUCE_UADD)

---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  5 ++-
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  9 ++--
 llvm/include/llvm/CodeGen/TargetLowering.h    | 10 +++++
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  7 +++-
 .../SelectionDAG/SelectionDAGDumper.cpp       |  6 ++-
 .../Target/AArch64/AArch64ISelLowering.cpp    | 42 +++++++++++--------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  2 +
 7 files changed, 55 insertions(+), 26 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 63a9e322400753..aac5ff6b371e81 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,11 +1451,12 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
-  // This corresponds to the `llvm.experimental.vector.partial.reduce.add`
+  // These correspond to the `llvm.experimental.vector.partial.reduce.add`
   // intrinsic
   // Operands: Accumulator, Input
   // Outputs: Output
-  PARTIAL_REDUCE_ADD,
+  PARTIAL_REDUCE_SADD,
+  PARTIAL_REDUCE_UADD,
 
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 7d7637a2d4df6e..acc08f2df234d0 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1599,10 +1599,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
-  /// Expands PARTIAL_REDUCE_ADD nodes which can't be lowered.
-  /// @param Op1 Accumulator for where the result is stored for the partial
-  /// reduction operation
-  /// @param Op2 Input for the partial reduction operation
+  /// Expands PARTIAL_REDUCE_S/UADD nodes to a sequence of subvector extracts
+  /// followed by vector adds.
+  /// \p Op1 Accumulator for where the result is stored for the partial
+  /// reduction operation.
+  /// \p Op2 Input for the partial reduction operation.
   SDValue expandPartialReduceAdd(SDLoc DL, SDValue Op1, SDValue Op2);
 
   /// Expands a node with multiple results to an FP or vector libcall. The
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index a207f3886bd0e8..caf66386237a4b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -462,6 +462,16 @@ class TargetLoweringBase {
     return true;
   }
 
+  /// Return true if there is a sign extend on the input to this function. Used
+  /// to determine whether to transform the
+  /// @llvm.experimental.vector.partial.reduce.* intrinsic to
+  /// PARTIAL_REDUCE_SADD or PARTIAL_REDUCE_UADD. It also removes the extend
+  /// from the input. \p Input The 'Input' operand to the
+  /// @llvm.experimental.vector.partial.reduce.* intrinsic.
+  virtual bool isPartialReductionInputSigned(SDValue &Input) const {
+    return false;
+  }
+
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
   /// using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 948fea268d18fb..ec649c7feced08 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8146,7 +8146,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     SDValue Input = getValue(I.getOperand(1));
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
-      setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_ADD, dl, AccVT, Acc, Input));
+      if (TLI.isPartialReductionInputSigned(Input))
+        setValue(&I,
+                 DAG.getNode(ISD::PARTIAL_REDUCE_SADD, dl, AccVT, Acc, Input));
+      else
+        setValue(&I,
+                 DAG.getNode(ISD::PARTIAL_REDUCE_UADD, dl, AccVT, Acc, Input));
       return;
     }
     setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 8ce03b14bda46c..1a710a47095189 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -567,8 +567,10 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return "histogram";
 
-  case ISD::PARTIAL_REDUCE_ADD:
-    return "partial_reduce_add";
+  case ISD::PARTIAL_REDUCE_UADD:
+    return "partial_reduce_uadd";
+  case ISD::PARTIAL_REDUCE_SADD:
+    return "partial_reduce_sadd";
 
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 152d3631c85970..2543025f825ad4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1124,7 +1124,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(
       {ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
-  setTargetDAGCombine(ISD::PARTIAL_REDUCE_ADD);
+  setTargetDAGCombine({ISD::PARTIAL_REDUCE_SADD, ISD::PARTIAL_REDUCE_UADD});
 
   setTargetDAGCombine(ISD::FP_EXTEND);
 
@@ -2043,13 +2043,30 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   EVT VT = EVT::getEVT(I->getType());
   auto Op1 = I->getOperand(1);
   EVT Op1VT = EVT::getEVT(Op1->getType());
-  if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
-      (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
-       VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
+  if ((Op1VT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
+      (Op1VT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
+      (Op1VT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
+      (Op1VT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
+      (Op1VT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
+      (Op1VT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
+      (Op1VT == MVT::v16i64 && VT == MVT::v4i64) ||
+      (Op1VT == MVT::v16i32 && VT == MVT::v4i32) ||
+      (Op1VT == MVT::v8i32 && VT == MVT::v2i32))
     return false;
   return true;
 }
 
+bool AArch64TargetLowering::isPartialReductionInputSigned(
+    SDValue &Input) const {
+  unsigned InputOpcode = Input.getOpcode();
+  if (ISD::isExtOpcode(InputOpcode)) {
+    Input = Input.getOperand(0);
+    if (InputOpcode == ISD::SIGN_EXTEND)
+      return true;
+  }
+  return false;
+}
+
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return true;
@@ -21815,19 +21832,9 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
   SDLoc DL(N);
 
   auto Acc = N->getOperand(0);
-  auto ExtInput = N->getOperand(1);
+  auto Input = N->getOperand(1);
 
   EVT AccVT = Acc.getValueType();
-  EVT AccElemVT = AccVT.getVectorElementType();
-
-  if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
-    return SDValue();
-
-  unsigned ExtInputOpcode = ExtInput->getOpcode();
-  if (!ISD::isExtOpcode(ExtInputOpcode))
-    return SDValue();
-
-  auto Input = ExtInput->getOperand(0);
   EVT InputVT = Input.getValueType();
 
   if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
@@ -21835,7 +21842,7 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
       !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
     return SDValue();
 
-  bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
+  bool InputIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SADD;
   auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
   auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
   auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
@@ -26151,7 +26158,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::MSCATTER:
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return performMaskedGatherScatterCombine(N, DCI, DAG);
-  case ISD::PARTIAL_REDUCE_ADD:
+  case ISD::PARTIAL_REDUCE_UADD:
+  case ISD::PARTIAL_REDUCE_SADD:
     return performPartialReduceAddCombine(N, DAG, Subtarget);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index cb0b9e965277aa..be0b4e5982f233 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -983,6 +983,8 @@ class AArch64TargetLowering : public TargetLowering {
   bool
   shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
 
+  bool isPartialReductionInputSigned(SDValue &Input) const override;
+
   bool shouldExpandCttzElements(EVT VT) const override;
 
   bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;

>From 649fc6bec83105d69bbb5d03500daaa8b150db0f Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 9 Dec 2024 13:41:41 +0000
Subject: [PATCH 6/7] Determine which ISD node to use in the DAG combine
 rather than in SelectionDAGBuilder.

---
 llvm/include/llvm/CodeGen/ISDOpcodes.h        |  4 +-
 llvm/include/llvm/CodeGen/SelectionDAG.h      |  3 +-
 llvm/include/llvm/CodeGen/TargetLowering.h    | 10 ---
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  8 +-
 .../Target/AArch64/AArch64ISelLowering.cpp    | 90 +++++++++----------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  2 -
 6 files changed, 49 insertions(+), 68 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index aac5ff6b371e81..51433001368571 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,8 +1451,8 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
-  // These correspond to the `llvm.experimental.vector.partial.reduce.add`
-  // intrinsic
+  // Nodes used to represent a partial reduction addition operation (signed and
+  // unsigned).
   // Operands: Accumulator, Input
   // Outputs: Output
   PARTIAL_REDUCE_SADD,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index acc08f2df234d0..80ba06ad39d2a1 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1599,8 +1599,7 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
-  /// Expands PARTIAL_REDUCE_S/UADD nodes to a sequence of subvector extracts
-  /// followed by vector adds.
+  /// Expands PARTIAL_REDUCE_S/UADD nodes.
   /// \p Op1 Accumulator for where the result is stored for the partial
   /// reduction operation.
   /// \p Op2 Input for the partial reduction operation.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index caf66386237a4b..a207f3886bd0e8 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -462,16 +462,6 @@ class TargetLoweringBase {
     return true;
   }
 
-  /// Return true if there is a sign extend on the input to this function. Used
-  /// to determine whether to transform the
-  /// @llvm.experimental.vector.partial.reduce.* intrinsic to
-  /// PARTIAL_REDUCE_SADD or PARTIAL_REDUCE_UADD. It also removes the extend
-  /// from the input. \p Input The 'Input' operand to the
-  /// @llvm.experimental.vector.partial.reduce.* intrinsic.
-  virtual bool isPartialReductionInputSigned(SDValue &Input) const {
-    return false;
-  }
-
   /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
   /// using generic code in SelectionDAGBuilder.
   virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ec649c7feced08..06d7d09761b8f9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8146,12 +8146,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     SDValue Input = getValue(I.getOperand(1));
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
-      if (TLI.isPartialReductionInputSigned(Input))
-        setValue(&I,
-                 DAG.getNode(ISD::PARTIAL_REDUCE_SADD, dl, AccVT, Acc, Input));
-      else
-        setValue(&I,
-                 DAG.getNode(ISD::PARTIAL_REDUCE_UADD, dl, AccVT, Acc, Input));
+      setValue(&I,
+               DAG.getNode(ISD::PARTIAL_REDUCE_UADD, dl, AccVT, Acc, Input));
       return;
     }
     setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2543025f825ad4..4f2ccbc95f9404 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2056,17 +2056,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
   return true;
 }
 
-bool AArch64TargetLowering::isPartialReductionInputSigned(
-    SDValue &Input) const {
-  unsigned InputOpcode = Input.getOpcode();
-  if (ISD::isExtOpcode(InputOpcode)) {
-    Input = Input.getOperand(0);
-    if (InputOpcode == ISD::SIGN_EXTEND)
-      return true;
-  }
-  return false;
-}
-
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return true;
@@ -21758,22 +21747,43 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   if (MulOp->getOpcode() != ISD::MUL)
     return SDValue();
 
-  auto ExtA = MulOp->getOperand(0);
-  auto ExtB = MulOp->getOperand(1);
+  auto A = MulOp->getOperand(0);
+  auto B = MulOp->getOperand(1);
 
-  if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
-      !ISD::isExtOpcode(ExtB->getOpcode()))
-    return SDValue();
-  bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-  bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+  unsigned AOpcode = A->getOpcode();
+  unsigned BOpcode = B->getOpcode();
+  unsigned Opcode;
+  EVT ReducedType = N->getValueType(0);
+  EVT MulSrcType;
+  if (ISD::isExtOpcode(AOpcode) || ISD::isExtOpcode(BOpcode)) {
+    bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
+    bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
+
+    A = A->getOperand(0);
+    B = B->getOperand(0);
+    if (A.getValueType() != B.getValueType())
+      return SDValue();
 
-  auto A = ExtA->getOperand(0);
-  auto B = ExtB->getOperand(0);
-  if (A.getValueType() != B.getValueType())
-    return SDValue();
+    if (AIsSigned != BIsSigned) {
+      if (!Subtarget->hasMatMulInt8())
+        return SDValue();
 
-  EVT ReducedType = N->getValueType(0);
-  EVT MulSrcType = A.getValueType();
+      bool Scalable = N->getValueType(0).isScalableVT();
+      // There's no nxv2i64 version of usdot
+      if (Scalable && ReducedType != MVT::nxv4i32 &&
+          ReducedType != MVT::nxv4i64)
+        return SDValue();
+
+      Opcode = AArch64ISD::USDOT;
+      // USDOT expects the signed operand to be last
+      if (!BIsSigned)
+        std::swap(A, B);
+    } else if (AIsSigned)
+      Opcode = AArch64ISD::SDOT;
+    else
+      Opcode = AArch64ISD::UDOT;
+    MulSrcType = A.getValueType();
+  }
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
@@ -21785,26 +21795,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
       !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
     return SDValue();
 
-  // If the extensions are mixed, we should lower it to a usdot instead
-  unsigned Opcode = 0;
-  if (AIsSigned != BIsSigned) {
-    if (!Subtarget->hasMatMulInt8())
-      return SDValue();
-
-    bool Scalable = N->getValueType(0).isScalableVT();
-    // There's no nxv2i64 version of usdot
-    if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
-      return SDValue();
-
-    Opcode = AArch64ISD::USDOT;
-    // USDOT expects the signed operand to be last
-    if (!BIsSigned)
-      std::swap(A, B);
-  } else if (AIsSigned)
-    Opcode = AArch64ISD::SDOT;
-  else
-    Opcode = AArch64ISD::UDOT;
-
   // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
   // product followed by a zero / sign extension
   if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
@@ -21834,15 +21824,23 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
   auto Acc = N->getOperand(0);
   auto Input = N->getOperand(1);
 
-  EVT AccVT = Acc.getValueType();
+  unsigned Opcode = N->getOpcode();
+  unsigned InputOpcode = Input.getOpcode();
+  if (ISD::isExtOpcode(InputOpcode)) {
+    Input = Input.getOperand(0);
+    if (InputOpcode == ISD::SIGN_EXTEND)
+      Opcode = ISD::PARTIAL_REDUCE_SADD;
+  }
+
   EVT InputVT = Input.getValueType();
+  EVT AccVT = Acc.getValueType();
 
   if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
       !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
       !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
     return SDValue();
 
-  bool InputIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SADD;
+  bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
   auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
   auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
   auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index be0b4e5982f233..cb0b9e965277aa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -983,8 +983,6 @@ class AArch64TargetLowering : public TargetLowering {
   bool
   shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
 
-  bool isPartialReductionInputSigned(SDValue &Input) const override;
-
   bool shouldExpandCttzElements(EVT VT) const override;
 
   bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;

>From 2aa668cb9301e6c0f00120cef824efed75461e06 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 11 Dec 2024 14:41:56 +0000
Subject: [PATCH 7/7] Separate lowering code for PARTIAL_REDUCE_U/SADD

Move most of the lowering code out of the DAG-combine function. The
DAG combine now decides whether the node should be the signed or
unsigned version of the partial reduce add. A new function,
LowerPARTIAL_REDUCE_ADD (called from LowerOperation), then performs
the actual lowering to wide adds or dot products where possible.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 249 ++++++++++--------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |   1 +
 2 files changed, 145 insertions(+), 105 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4f2ccbc95f9404..ee75d4b0371e34 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1841,8 +1841,17 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
                          Custom);
     }
+
+    for (auto VT : {MVT::nxv2i64, MVT::nxv4i32, MVT::nxv8i16}) {
+      setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
+      setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
+    }
   }
 
+  for (auto VT : {MVT::v4i64, MVT::v4i32, MVT::v2i32}) {
+    setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
+    setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
+  }
 
   if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
     // Only required for llvm.aarch64.mops.memset.tag
@@ -2041,17 +2050,18 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   EVT VT = EVT::getEVT(I->getType());
-  auto Op1 = I->getOperand(1);
-  EVT Op1VT = EVT::getEVT(Op1->getType());
-  if ((Op1VT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
-      (Op1VT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
-      (Op1VT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
-      (Op1VT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
-      (Op1VT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
-      (Op1VT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
-      (Op1VT == MVT::v16i64 && VT == MVT::v4i64) ||
-      (Op1VT == MVT::v16i32 && VT == MVT::v4i32) ||
-      (Op1VT == MVT::v8i32 && VT == MVT::v2i32))
+  auto Input = I->getOperand(1);
+  EVT InputVT = EVT::getEVT(Input->getType());
+
+  if ((InputVT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
+      (InputVT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
+      (InputVT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
+      (InputVT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
+      (InputVT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
+      (InputVT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
+      (InputVT == MVT::v16i64 && VT == MVT::v4i64) ||
+      (InputVT == MVT::v16i32 && VT == MVT::v4i32) ||
+      (InputVT == MVT::v8i32 && VT == MVT::v2i32))
     return false;
   return true;
 }
@@ -7519,6 +7529,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerFLDEXP(Op, DAG);
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return LowerVECTOR_HISTOGRAM(Op, DAG);
+  case ISD::PARTIAL_REDUCE_UADD:
+  case ISD::PARTIAL_REDUCE_SADD:
+    return LowerPARTIAL_REDUCE_ADD(Op, DAG);
   }
 }
 
@@ -21729,133 +21742,126 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N,
-                                      const AArch64Subtarget *Subtarget,
-                                      SelectionDAG &DAG) {
-
-  bool Scalable = N->getValueType(0).isScalableVector();
+SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
+                               const AArch64Subtarget *Subtarget, SDLoc &DL) {
+  bool Scalable = Acc.getValueType().isScalableVector();
   if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
   if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
     return SDValue();
 
-  SDLoc DL(N);
-
-  // The narrower of the two operands. Used as the accumulator
-  auto NarrowOp = N->getOperand(0);
-  auto MulOp = N->getOperand(1);
-  if (MulOp->getOpcode() != ISD::MUL)
+  unsigned InputOpcode = Input->getOpcode();
+  if (InputOpcode != ISD::MUL)
     return SDValue();
-
-  auto A = MulOp->getOperand(0);
-  auto B = MulOp->getOperand(1);
-
+  auto A = Input->getOperand(0);
+  auto B = Input->getOperand(1);
   unsigned AOpcode = A->getOpcode();
   unsigned BOpcode = B->getOpcode();
-  unsigned Opcode;
-  EVT ReducedType = N->getValueType(0);
-  EVT MulSrcType;
-  if (ISD::isExtOpcode(AOpcode) || ISD::isExtOpcode(BOpcode)) {
-    bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
-    bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
-
-    A = A->getOperand(0);
-    B = B->getOperand(0);
-    if (A.getValueType() != B.getValueType())
-      return SDValue();
+  EVT AccVT = Acc->getValueType(0);
 
-    if (AIsSigned != BIsSigned) {
-      if (!Subtarget->hasMatMulInt8())
-        return SDValue();
+  if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
+    return DAG.expandPartialReduceAdd(DL, Acc, Input);
 
-      bool Scalable = N->getValueType(0).isScalableVT();
-      // There's no nxv2i64 version of usdot
-      if (Scalable && ReducedType != MVT::nxv4i32 &&
-          ReducedType != MVT::nxv4i64)
-        return SDValue();
+  bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
+  bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
 
-      Opcode = AArch64ISD::USDOT;
-      // USDOT expects the signed operand to be last
-      if (!BIsSigned)
-        std::swap(A, B);
-    } else if (AIsSigned)
-      Opcode = AArch64ISD::SDOT;
-    else
-      Opcode = AArch64ISD::UDOT;
-    MulSrcType = A.getValueType();
-  }
+  A = A->getOperand(0);
+  B = B->getOperand(0);
+  EVT MulSrcVT = A.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
-      !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
-      !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
-      !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
-      !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
-      !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
-    return SDValue();
+  if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+      !(AccVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+      !(AccVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+      !(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+      !(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+      !(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+    return DAG.expandPartialReduceAdd(DL, Acc, Input);
+
+  unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+  if (AIsSigned != BIsSigned) {
+    if (!Subtarget->hasMatMulInt8())
+      return DAG.expandPartialReduceAdd(DL, Acc, Input);
+
+    bool Scalable = AccVT.isScalableVT();
+    // There's no nxv2i64 version of usdot
+    if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
+      return DAG.expandPartialReduceAdd(DL, Acc, Input);
+
+    if (!BIsSigned)
+      std::swap(A, B);
+    DotOpcode = AArch64ISD::USDOT;
+    // Lower usdot patterns here because legalisation would attempt to split it
+    // unless exts are removed. But, removing the exts would lose the
+    // information about whether each operand is signed.
+    if ((AccVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
+        (AccVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
+      return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
+  }
 
   // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
-  // product followed by a zero / sign extension
-  if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
-      (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
-    EVT ReducedTypeI32 =
-        (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+  // product followed by a zero / sign extension. Need to lower this here
+  // because legalisation would attempt to split it.
+  if ((AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+      (AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+    EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
 
-    auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
-                              DAG.getConstant(0, DL, ReducedTypeI32), A, B);
-    auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
-    return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
-                       Extended);
+    auto DotI32 = DAG.getNode(DotOpcode, DL, AccVTI32,
+                              DAG.getConstant(0, DL, AccVTI32), A, B);
+    auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
+    return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
   }
 
-  return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
-}
+  if (A.getValueType() != B.getValueType())
+    return DAG.expandPartialReduceAdd(DL, Acc, Input);
 
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
-                                          const AArch64Subtarget *Subtarget,
-                                          SelectionDAG &DAG) {
+  unsigned NewOpcode =
+      AIsSigned ? ISD::PARTIAL_REDUCE_SADD : ISD::PARTIAL_REDUCE_UADD;
+  auto NewMul = DAG.getNode(ISD::MUL, DL, A.getValueType(), A, B);
+  return DAG.getNode(NewOpcode, DL, AccVT, Acc, NewMul);
+}
 
+SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
+                            const AArch64Subtarget *Subtarget, SDLoc &DL) {
   if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
-    return SDValue();
-
-  SDLoc DL(N);
-
-  auto Acc = N->getOperand(0);
-  auto Input = N->getOperand(1);
-
-  unsigned Opcode = N->getOpcode();
-  unsigned InputOpcode = Input.getOpcode();
-  if (ISD::isExtOpcode(InputOpcode)) {
-    Input = Input.getOperand(0);
-    if (InputOpcode == ISD::SIGN_EXTEND)
-      Opcode = ISD::PARTIAL_REDUCE_SADD;
-  }
-
+    return DAG.expandPartialReduceAdd(DL, Acc, Input);
+  unsigned InputOpcode = Input->getOpcode();
+  if (!ISD::isExtOpcode(InputOpcode))
+    return DAG.expandPartialReduceAdd(DL, Acc, Input);
+  Input = Input->getOperand(0);
   EVT InputVT = Input.getValueType();
-  EVT AccVT = Acc.getValueType();
+  EVT AccVT = Acc->getValueType(0);
 
   if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
       !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
       !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
     return SDValue();
 
-  bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
-  auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
-  auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
-  auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
-  return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
+  unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
+                           ? ISD::PARTIAL_REDUCE_SADD
+                           : ISD::PARTIAL_REDUCE_UADD;
+  return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input);
 }
 
-static SDValue
-performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
-                               const AArch64Subtarget *Subtarget) {
-  if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+                                       const AArch64Subtarget *Subtarget) {
+  SDLoc DL(N);
+  auto Acc = N->getOperand(0);
+  auto Input = N->getOperand(1);
+  EVT AccElemVT = Acc.getValueType().getVectorElementType();
+  EVT InputElemVT = Input.getValueType().getVectorElementType();
+
+  // If the exts have already been removed or it has already been lowered to an
+  // usdot instruction, then the element types will not be equal
+  if (InputElemVT != AccElemVT || Input.getOpcode() == AArch64ISD::USDOT)
+    return SDValue(N, 0);
+
+  if (auto Dot = tryCombineToDotProduct(Acc, Input, DAG, Subtarget, DL))
     return Dot;
-  if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+  if (auto WideAdd = tryCombineToWideAdd(Acc, Input, DAG, Subtarget, DL))
     return WideAdd;
-  return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0),
-                                    N->getOperand(1));
+  return SDValue();
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,
@@ -28919,6 +28925,39 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   return Scatter;
 }
 
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Acc = Op.getOperand(0);
+  SDValue Input = Op.getOperand(1);
+
+  EVT AccVT = Acc.getValueType();
+  EVT InputVT = Input.getValueType();
+
+  unsigned Opcode = Op.getOpcode();
+
+  if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
+    unsigned IndexAdd = 0;
+    // ISD::MUL may have already been lowered, meaning the operands would be in
+    // different positions.
+    if (Input.getOpcode() != ISD::MUL)
+      IndexAdd = 1;
+    auto A = Input.getOperand(IndexAdd);
+    auto B = Input.getOperand(IndexAdd + 1);
+
+    unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SADD ? AArch64ISD::SDOT
+                                                            : AArch64ISD::UDOT;
+    return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
+  }
+  bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
+  unsigned BottomOpcode =
+      InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+  unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+  auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
+  return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
+}
+
 SDValue
 AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
                                                     SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index cb0b9e965277aa..20eb9232c08091 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1171,6 +1171,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerPARTIAL_REDUCE_ADD(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;



More information about the llvm-commits mailing list